diff --git a/.gitignore b/.gitignore index 6a52691..e27209c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,2 @@ -olm .DS_Store bin/ \ No newline at end of file diff --git a/Makefile b/Makefile index 433e275..7e4cdf9 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ docker-build-release: docker buildx build --platform linux/arm/v7,linux/arm64,linux/amd64 -t fosrl/olm:$(tag) -f Dockerfile --push . local: - CGO_ENABLED=0 go build -o olm + CGO_ENABLED=0 go build -o bin/olm build: docker build -t fosrl/olm:latest . diff --git a/api/api.go b/api/api.go new file mode 100644 index 0000000..787f958 --- /dev/null +++ b/api/api.go @@ -0,0 +1,503 @@ +package api + +import ( + "encoding/json" + "fmt" + "net" + "net/http" + "sync" + "time" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/network" +) + +// ConnectionRequest defines the structure for an incoming connection request +type ConnectionRequest struct { + ID string `json:"id"` + Secret string `json:"secret"` + Endpoint string `json:"endpoint"` + UserToken string `json:"userToken,omitempty"` + MTU int `json:"mtu,omitempty"` + DNS string `json:"dns,omitempty"` + DNSProxyIP string `json:"dnsProxyIP,omitempty"` + UpstreamDNS []string `json:"upstreamDNS,omitempty"` + InterfaceName string `json:"interfaceName,omitempty"` + Holepunch bool `json:"holepunch,omitempty"` + TlsClientCert string `json:"tlsClientCert,omitempty"` + PingInterval string `json:"pingInterval,omitempty"` + PingTimeout string `json:"pingTimeout,omitempty"` + OrgID string `json:"orgId,omitempty"` +} + +// SwitchOrgRequest defines the structure for switching organizations +type SwitchOrgRequest struct { + OrgID string `json:"orgId"` +} + +// PeerStatus represents the status of a peer connection +type PeerStatus struct { + SiteID int `json:"siteId"` + Name string `json:"name"` + Connected bool `json:"connected"` + RTT time.Duration `json:"rtt"` + LastSeen time.Time `json:"lastSeen"` + Endpoint string `json:"endpoint,omitempty"` + IsRelay 
bool `json:"isRelay"` + PeerIP string `json:"peerAddress,omitempty"` + HolepunchConnected bool `json:"holepunchConnected"` +} + +// StatusResponse is returned by the status endpoint +type StatusResponse struct { + Connected bool `json:"connected"` + Registered bool `json:"registered"` + Terminated bool `json:"terminated"` + Version string `json:"version,omitempty"` + Agent string `json:"agent,omitempty"` + OrgID string `json:"orgId,omitempty"` + PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"` + NetworkSettings network.NetworkSettings `json:"networkSettings,omitempty"` +} + +// API represents the HTTP server and its state +type API struct { + addr string + socketPath string + listener net.Listener + server *http.Server + onConnect func(ConnectionRequest) error + onSwitchOrg func(SwitchOrgRequest) error + onDisconnect func() error + onExit func() error + statusMu sync.RWMutex + peerStatuses map[int]*PeerStatus + connectedAt time.Time + isConnected bool + isRegistered bool + isTerminated bool + version string + agent string + orgID string +} + +// NewAPI creates a new HTTP server that listens on a TCP address +func NewAPI(addr string) *API { + s := &API{ + addr: addr, + peerStatuses: make(map[int]*PeerStatus), + } + + return s +} + +// NewAPISocket creates a new HTTP server that listens on a Unix socket or Windows named pipe +func NewAPISocket(socketPath string) *API { + s := &API{ + socketPath: socketPath, + peerStatuses: make(map[int]*PeerStatus), + } + + return s +} + +// SetHandlers sets the callback functions for handling API requests +func (s *API) SetHandlers( + onConnect func(ConnectionRequest) error, + onSwitchOrg func(SwitchOrgRequest) error, + onDisconnect func() error, + onExit func() error, +) { + s.onConnect = onConnect + s.onSwitchOrg = onSwitchOrg + s.onDisconnect = onDisconnect + s.onExit = onExit +} + +// Start starts the HTTP server +func (s *API) Start() error { + mux := http.NewServeMux() + mux.HandleFunc("/connect", s.handleConnect) + 
mux.HandleFunc("/status", s.handleStatus) + mux.HandleFunc("/switch-org", s.handleSwitchOrg) + mux.HandleFunc("/disconnect", s.handleDisconnect) + mux.HandleFunc("/exit", s.handleExit) + mux.HandleFunc("/health", s.handleHealth) + + s.server = &http.Server{ + Handler: mux, + // Bound how long a client may take to send its request headers so a stalled connection cannot be held open forever + ReadHeaderTimeout: 10 * time.Second, + } + + var err error + if s.socketPath != "" { + // Use platform-specific socket listener + s.listener, err = createSocketListener(s.socketPath) + if err != nil { + return fmt.Errorf("failed to create socket listener: %w", err) + } + logger.Info("Starting HTTP server on socket %s", s.socketPath) + } else { + // Use TCP listener + s.listener, err = net.Listen("tcp", s.addr) + if err != nil { + return fmt.Errorf("failed to create TCP listener: %w", err) + } + logger.Info("Starting HTTP server on %s", s.addr) + } + + go func() { + if err := s.server.Serve(s.listener); err != nil && err != http.ErrServerClosed { + logger.Error("HTTP server error: %v", err) + } + }() + + return nil +} + +// Stop stops the HTTP server and returns any error reported while closing it +func (s *API) Stop() error { + logger.Info("Stopping api server") + + // Close shuts the server and its listener down immediately; it does not wait for in-flight requests + var closeErr error + if s.server != nil { + closeErr = s.server.Close() + } + + // Clean up socket file if using Unix socket + if s.socketPath != "" { + cleanupSocket(s.socketPath) + } + + return closeErr +} + +// AddPeerStatus creates or updates the status entry for a peer, including its name, and sets LastSeen to the current time +func (s *API) AddPeerStatus(siteID int, siteName string, connected bool, rtt time.Duration, endpoint string, isRelay bool) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + + status, exists := s.peerStatuses[siteID] + if !exists { + status = &PeerStatus{ + SiteID: siteID, + } + s.peerStatuses[siteID] = status + } + + status.Name = siteName + status.Connected = connected + status.RTT = rtt + status.LastSeen = time.Now() + status.Endpoint = endpoint + status.IsRelay = isRelay +} + +// UpdatePeerStatus updates the status of a peer including endpoint and relay info +func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
+ s.statusMu.Lock() + defer s.statusMu.Unlock() + + status, exists := s.peerStatuses[siteID] + if !exists { + status = &PeerStatus{ + SiteID: siteID, + } + s.peerStatuses[siteID] = status + } + + status.Connected = connected + status.RTT = rtt + status.LastSeen = time.Now() + status.Endpoint = endpoint + status.IsRelay = isRelay +} + +func (s *API) RemovePeerStatus(siteID int) { // remove the peer from the status map + s.statusMu.Lock() + defer s.statusMu.Unlock() + delete(s.peerStatuses, siteID) +} + +// SetConnectionStatus sets the overall connection status +func (s *API) SetConnectionStatus(isConnected bool) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + + s.isConnected = isConnected + + if isConnected { + s.connectedAt = time.Now() + } else { + // Clear peer statuses when disconnected + s.peerStatuses = make(map[int]*PeerStatus) + } +} + +func (s *API) SetRegistered(registered bool) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + s.isRegistered = registered +} + +func (s *API) SetTerminated(terminated bool) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + s.isTerminated = terminated +} + +// ClearPeerStatuses clears all peer statuses +func (s *API) ClearPeerStatuses() { + s.statusMu.Lock() + defer s.statusMu.Unlock() + s.peerStatuses = make(map[int]*PeerStatus) +} + +// SetVersion sets the olm version +func (s *API) SetVersion(version string) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + s.version = version +} + +// SetAgent sets the olm agent +func (s *API) SetAgent(agent string) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + s.agent = agent +} + +// SetOrgID sets the organization ID +func (s *API) SetOrgID(orgID string) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + s.orgID = orgID +} + +// UpdatePeerRelayStatus updates only the relay status of a peer +func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + + status, exists := s.peerStatuses[siteID] + if 
!exists { + status = &PeerStatus{ + SiteID: siteID, + } + s.peerStatuses[siteID] = status + } + + status.Endpoint = endpoint + status.IsRelay = isRelay +} + +// UpdatePeerHolepunchStatus updates the holepunch connection status of a peer +func (s *API) UpdatePeerHolepunchStatus(siteID int, holepunchConnected bool) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + + status, exists := s.peerStatuses[siteID] + if !exists { + status = &PeerStatus{ + SiteID: siteID, + } + s.peerStatuses[siteID] = status + } + + status.HolepunchConnected = holepunchConnected +} + +// handleConnect handles the /connect endpoint +func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // if we are already connected, reject new connection requests + s.statusMu.RLock() + alreadyConnected := s.isConnected + s.statusMu.RUnlock() + if alreadyConnected { + http.Error(w, "Already connected to a server. 
Disconnect first before connecting again.", http.StatusConflict) + return + } + + var req ConnectionRequest + decoder := json.NewDecoder(r.Body) + if err := decoder.Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest) + return + } + + // Validate required fields + if req.ID == "" || req.Secret == "" || req.Endpoint == "" { + http.Error(w, "Missing required fields: id, secret, and endpoint must be provided", http.StatusBadRequest) + return + } + + // Call the connect handler if set + if s.onConnect != nil { + if err := s.onConnect(req); err != nil { + http.Error(w, fmt.Sprintf("Connection failed: %v", err), http.StatusInternalServerError) + return + } + } + + // Return a success response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + json.NewEncoder(w).Encode(map[string]string{ + "status": "connection request accepted", + }) +} + +// handleStatus handles the /status endpoint +func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + s.statusMu.RLock() + defer s.statusMu.RUnlock() + + resp := StatusResponse{ + Connected: s.isConnected, + Registered: s.isRegistered, + Terminated: s.isTerminated, + Version: s.version, + Agent: s.agent, + OrgID: s.orgID, + PeerStatuses: s.peerStatuses, + NetworkSettings: network.GetSettings(), + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} + +// handleHealth handles the /health endpoint +func (s *API) handleHealth(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{ + "status": "ok", + }) +} + +// handleExit handles the /exit 
endpoint +func (s *API) handleExit(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + logger.Info("Received exit request via API") + + // Return a success response first + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{ + "status": "shutdown initiated", + }) + + // Call the exit handler after responding, in a goroutine with a small delay + // to ensure the response is fully sent before shutdown begins + if s.onExit != nil { + go func() { + time.Sleep(100 * time.Millisecond) + if err := s.onExit(); err != nil { + logger.Error("Exit handler failed: %v", err) + } + }() + } +} + +// handleSwitchOrg handles the /switch-org endpoint +func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req SwitchOrgRequest + decoder := json.NewDecoder(r.Body) + if err := decoder.Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest) + return + } + + // Validate required fields + if req.OrgID == "" { + http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest) + return + } + + logger.Info("Received org switch request to orgId: %s", req.OrgID) + + // Call the switch org handler if set + if s.onSwitchOrg != nil { + if err := s.onSwitchOrg(req); err != nil { + http.Error(w, fmt.Sprintf("Org switch failed: %v", err), http.StatusInternalServerError) + return + } + } + + // Return a success response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{ + "status": "org switch request accepted", + }) +} + +// handleDisconnect handles the /disconnect endpoint +func (s *API) handleDisconnect(w http.ResponseWriter, r 
*http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // if we are already disconnected, reject new disconnect requests + s.statusMu.RLock() + alreadyDisconnected := !s.isConnected + s.statusMu.RUnlock() + if alreadyDisconnected { + http.Error(w, "Not currently connected to a server.", http.StatusConflict) + return + } + + logger.Info("Received disconnect request via API") + + // Call the disconnect handler if set + if s.onDisconnect != nil { + if err := s.onDisconnect(); err != nil { + http.Error(w, fmt.Sprintf("Disconnect failed: %v", err), http.StatusInternalServerError) + return + } + } + + // Return a success response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{ + "status": "disconnect initiated", + }) +} + +// GetStatus returns a snapshot of the current API state. +// It reads under the status read lock and copies every peer entry, so the +// result stays stable while other goroutines keep updating peer statuses. +func (s *API) GetStatus() StatusResponse { + s.statusMu.RLock() + defer s.statusMu.RUnlock() + + // Copy the map and its entries: the setters mutate both while holding statusMu + peers := make(map[int]*PeerStatus, len(s.peerStatuses)) + for siteID, status := range s.peerStatuses { + snapshot := *status + peers[siteID] = &snapshot + } + + return StatusResponse{ + Connected: s.isConnected, + Registered: s.isRegistered, + Terminated: s.isTerminated, + Version: s.version, + Agent: s.agent, + OrgID: s.orgID, + PeerStatuses: peers, + NetworkSettings: network.GetSettings(), + } +} diff --git a/api/api_unix.go b/api/api_unix.go new file mode 100644 index 0000000..2dab602 --- /dev/null +++ b/api/api_unix.go @@ -0,0 +1,50 @@ +//go:build !windows +// +build !windows + +package api + +import ( + "fmt" + "net" + "os" + "path/filepath" + + "github.com/fosrl/newt/logger" +) + +// createSocketListener creates a Unix domain socket listener +func createSocketListener(socketPath string) (net.Listener, error) { + // Ensure the directory exists + dir := filepath.Dir(socketPath) + if err := os.MkdirAll(dir, 0755); err != nil { + return nil, fmt.Errorf("failed to create socket directory: %w", err) + } + + // Remove existing socket file if it exists + if err := os.RemoveAll(socketPath); err != nil { + return nil, fmt.Errorf("failed to remove existing socket: %w", err) + } + + listener, err :=
net.Listen("unix", socketPath) + if err != nil { + return nil, fmt.Errorf("failed to listen on Unix socket: %w", err) + } + + // Set socket permissions to allow access + // NOTE(review): 0666 lets every local user open the socket; confirm this is intended + if err := os.Chmod(socketPath, 0666); err != nil { + listener.Close() + return nil, fmt.Errorf("failed to set socket permissions: %w", err) + } + + logger.Debug("Created Unix socket at %s", socketPath) + return listener, nil +} + +// cleanupSocket removes the Unix socket file; a file that is already gone is not treated as an error +func cleanupSocket(socketPath string) { + err := os.Remove(socketPath) + switch { + case err == nil: + logger.Debug("Removed Unix socket at %s", socketPath) + case os.IsNotExist(err): + // Nothing was removed, so do not report a removal + logger.Debug("Unix socket at %s does not exist, nothing to remove", socketPath) + default: + logger.Error("Failed to remove socket file %s: %v", socketPath, err) + } +} diff --git a/api/api_windows.go b/api/api_windows.go new file mode 100644 index 0000000..d9ef373 --- /dev/null +++ b/api/api_windows.go @@ -0,0 +1,41 @@ +//go:build windows +// +build windows + +package api + +import ( + "fmt" + "net" + + "github.com/Microsoft/go-winio" + "github.com/fosrl/newt/logger" +) + +// createSocketListener creates a Windows named pipe listener +func createSocketListener(pipePath string) (net.Listener, error) { + // Reject an empty path up front: indexing pipePath[0] below would panic + if pipePath == "" { + return nil, fmt.Errorf("named pipe path must not be empty") + } + + // Ensure the pipe path has the correct format + if pipePath[0] != '\\' { + pipePath = `\\.\pipe\` + pipePath + } + + // Create a pipe configuration that allows everyone to write + config := &winio.PipeConfig{ + // Set security descriptor to allow everyone full access + // This SDDL string grants full access to Everyone (WD) and to the current owner (OW) + SecurityDescriptor: "D:(A;;GA;;;WD)(A;;GA;;;OW)", + } + + // Create a named pipe listener using go-winio with the configuration + listener, err := winio.ListenPipe(pipePath, config) + if err != nil { + return nil, fmt.Errorf("failed to listen on named pipe: %w", err) + } + + logger.Debug("Created named pipe at %s with write access for everyone", pipePath) + return listener, nil +} + +// cleanupSocket is a no-op on Windows as named pipes are automatically cleaned up +func
cleanupSocket(pipePath string) { + logger.Debug("Named pipe %s will be automatically cleaned up", pipePath) +} diff --git a/common.go b/common.go deleted file mode 100644 index 63d8ea4..0000000 --- a/common.go +++ /dev/null @@ -1,1145 +0,0 @@ -package main - -import ( - "encoding/base64" - "encoding/hex" - "encoding/json" - "fmt" - "net" - "os/exec" - "regexp" - "runtime" - "strconv" - "strings" - "time" - - "github.com/fosrl/newt/logger" - "github.com/fosrl/olm/peermonitor" - "github.com/fosrl/olm/websocket" - "github.com/vishvananda/netlink" - "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/crypto/curve25519" - "golang.org/x/exp/rand" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" -) - -type WgData struct { - Sites []SiteConfig `json:"sites"` - TunnelIP string `json:"tunnelIP"` -} - -type SiteConfig struct { - SiteId int `json:"siteId"` - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` - ServerIP string `json:"serverIP"` - ServerPort uint16 `json:"serverPort"` - RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access -} - -type TargetsByType struct { - UDP []string `json:"udp"` - TCP []string `json:"tcp"` -} - -type TargetData struct { - Targets []string `json:"targets"` -} - -type HolePunchMessage struct { - NewtID string `json:"newtId"` -} - -type ExitNode struct { - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` -} - -type HolePunchData struct { - ExitNodes []ExitNode `json:"exitNodes"` -} - -type EncryptedHolePunchMessage struct { - EphemeralPublicKey string `json:"ephemeralPublicKey"` - Nonce []byte `json:"nonce"` - Ciphertext []byte `json:"ciphertext"` -} - -var ( - peerMonitor *peermonitor.PeerMonitor - stopHolepunch chan struct{} - stopRegister func() - stopPing chan struct{} - olmToken string - holePunchRunning bool -) - -const ( - ENV_WG_TUN_FD = "WG_TUN_FD" 
- ENV_WG_UAPI_FD = "WG_UAPI_FD" - ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND" -) - -type fixedPortBind struct { - port uint16 - conn.Bind -} - -// PeerAction represents a request to add, update, or remove a peer -type PeerAction struct { - Action string `json:"action"` // "add", "update", or "remove" - SiteInfo SiteConfig `json:"siteInfo"` // Site configuration information -} - -// UpdatePeerData represents the data needed to update a peer -type UpdatePeerData struct { - SiteId int `json:"siteId"` - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` - ServerIP string `json:"serverIP"` - ServerPort uint16 `json:"serverPort"` - RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access -} - -// AddPeerData represents the data needed to add a peer -type AddPeerData struct { - SiteId int `json:"siteId"` - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` - ServerIP string `json:"serverIP"` - ServerPort uint16 `json:"serverPort"` - RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access -} - -// RemovePeerData represents the data needed to remove a peer -type RemovePeerData struct { - SiteId int `json:"siteId"` -} - -type RelayPeerData struct { - SiteId int `json:"siteId"` - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` -} - -func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) { - // Ignore the requested port and use our fixed port - return b.Bind.Open(b.port) -} - -func NewFixedPortBind(port uint16) conn.Bind { - return &fixedPortBind{ - port: port, - Bind: conn.NewDefaultBind(), - } -} - -func fixKey(key string) string { - // Remove any whitespace - key = strings.TrimSpace(key) - - // Decode from base64 - decoded, err := base64.StdEncoding.DecodeString(key) - if err != nil { - logger.Fatal("Error decoding base64") - } - - // Convert 
to hex - return hex.EncodeToString(decoded) -} - -func parseLogLevel(level string) logger.LogLevel { - switch strings.ToUpper(level) { - case "DEBUG": - return logger.DEBUG - case "INFO": - return logger.INFO - case "WARN": - return logger.WARN - case "ERROR": - return logger.ERROR - case "FATAL": - return logger.FATAL - default: - return logger.INFO // default to INFO if invalid level provided - } -} - -func mapToWireGuardLogLevel(level logger.LogLevel) int { - switch level { - case logger.DEBUG: - return device.LogLevelVerbose - // case logger.INFO: - // return device.LogLevel - case logger.WARN: - return device.LogLevelError - case logger.ERROR, logger.FATAL: - return device.LogLevelSilent - default: - return device.LogLevelSilent - } -} - -func resolveDomain(domain string) (string, error) { - // First handle any protocol prefix - domain = strings.TrimPrefix(strings.TrimPrefix(domain, "https://"), "http://") - - // if there are any trailing slashes, remove them - domain = strings.TrimSuffix(domain, "/") - - // Now split host and port - host, port, err := net.SplitHostPort(domain) - if err != nil { - // No port found, use the domain as is - host = domain - port = "" - } - - // Lookup IP addresses - ips, err := net.LookupIP(host) - if err != nil { - return "", fmt.Errorf("DNS lookup failed: %v", err) - } - - if len(ips) == 0 { - return "", fmt.Errorf("no IP addresses found for domain %s", host) - } - - // Get the first IPv4 address if available - var ipAddr string - for _, ip := range ips { - if ipv4 := ip.To4(); ipv4 != nil { - ipAddr = ipv4.String() - break - } - } - - // If no IPv4 found, use the first IP (might be IPv6) - if ipAddr == "" { - ipAddr = ips[0].String() - } - - // Add port back if it existed - if port != "" { - ipAddr = net.JoinHostPort(ipAddr, port) - } - - return ipAddr, nil -} - -func sendUDPHolePunchWithConn(conn *net.UDPConn, remoteAddr *net.UDPAddr, olmID string, serverPubKey string) error { - if serverPubKey == "" || olmToken == "" { - 
return nil - } - - payload := struct { - OlmID string `json:"olmId"` - Token string `json:"token"` - }{ - OlmID: olmID, - Token: olmToken, - } - - // Convert payload to JSON - payloadBytes, err := json.Marshal(payload) - if err != nil { - return fmt.Errorf("failed to marshal payload: %v", err) - } - - // Encrypt the payload using the server's WireGuard public key - encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey) - if err != nil { - return fmt.Errorf("failed to encrypt payload: %v", err) - } - - jsonData, err := json.Marshal(encryptedPayload) - if err != nil { - return fmt.Errorf("failed to marshal encrypted payload: %v", err) - } - - _, err = conn.WriteToUDP(jsonData, remoteAddr) - if err != nil { - return fmt.Errorf("failed to send UDP packet: %v", err) - } - - logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData)) - - return nil -} - -func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) { - // Generate an ephemeral keypair for this message - ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey() - if err != nil { - return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err) - } - ephemeralPublicKey := ephemeralPrivateKey.PublicKey() - - // Parse the server's public key - serverPubKey, err := wgtypes.ParseKey(serverPublicKey) - if err != nil { - return nil, fmt.Errorf("failed to parse server public key: %v", err) - } - - // Use X25519 for key exchange (replacing deprecated ScalarMult) - var ephPrivKeyFixed [32]byte - copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:]) - - // Perform X25519 key exchange - sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:]) - if err != nil { - return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err) - } - - // Create an AEAD cipher using the shared secret - aead, err := chacha20poly1305.New(sharedSecret) - if err != nil { - return nil, fmt.Errorf("failed to create AEAD cipher: %v", err) - } - - // 
Generate a random nonce - nonce := make([]byte, aead.NonceSize()) - if _, err := rand.Read(nonce); err != nil { - return nil, fmt.Errorf("failed to generate nonce: %v", err) - } - - // Encrypt the payload - ciphertext := aead.Seal(nil, nonce, payload, nil) - - // Prepare the final encrypted message - encryptedMsg := struct { - EphemeralPublicKey string `json:"ephemeralPublicKey"` - Nonce []byte `json:"nonce"` - Ciphertext []byte `json:"ciphertext"` - }{ - EphemeralPublicKey: ephemeralPublicKey.String(), - Nonce: nonce, - Ciphertext: ciphertext, - } - - return encryptedMsg, nil -} - -func keepSendingUDPHolePunchToMultipleExitNodes(exitNodes []ExitNode, olmID string, sourcePort uint16) { - if len(exitNodes) == 0 { - logger.Warn("No exit nodes provided for hole punching") - return - } - - // Check if hole punching is already running - if holePunchRunning { - logger.Debug("UDP hole punch already running, skipping new request") - return - } - - // Set the flag to indicate hole punching is running - holePunchRunning = true - defer func() { - holePunchRunning = false - logger.Info("UDP hole punch goroutine ended") - }() - - logger.Info("Starting UDP hole punch to %d exit nodes", len(exitNodes)) - defer logger.Info("UDP hole punch goroutine ended for all exit nodes") - - // Create the UDP connection once and reuse it for all exit nodes - localAddr := &net.UDPAddr{ - Port: int(sourcePort), - IP: net.IPv4zero, - } - - conn, err := net.ListenUDP("udp", localAddr) - if err != nil { - logger.Error("Failed to bind UDP socket: %v", err) - return - } - defer conn.Close() - - // Resolve all endpoints upfront - type resolvedExitNode struct { - remoteAddr *net.UDPAddr - publicKey string - endpointName string - } - - var resolvedNodes []resolvedExitNode - for _, exitNode := range exitNodes { - host, err := resolveDomain(exitNode.Endpoint) - if err != nil { - logger.Error("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err) - continue - } - - serverAddr := 
net.JoinHostPort(host, "21820") - remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) - if err != nil { - logger.Error("Failed to resolve UDP address for %s: %v", exitNode.Endpoint, err) - continue - } - - resolvedNodes = append(resolvedNodes, resolvedExitNode{ - remoteAddr: remoteAddr, - publicKey: exitNode.PublicKey, - endpointName: exitNode.Endpoint, - }) - logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String()) - } - - if len(resolvedNodes) == 0 { - logger.Error("No exit nodes could be resolved") - return - } - - // Send initial hole punch to all exit nodes - for _, node := range resolvedNodes { - if err := sendUDPHolePunchWithConn(conn, node.remoteAddr, olmID, node.publicKey); err != nil { - logger.Error("Failed to send initial UDP hole punch to %s: %v", node.endpointName, err) - } - } - - ticker := time.NewTicker(250 * time.Millisecond) - defer ticker.Stop() - - timeout := time.NewTimer(15 * time.Second) - defer timeout.Stop() - - for { - select { - case <-stopHolepunch: - logger.Info("Stopping UDP holepunch for all exit nodes") - return - case <-timeout.C: - logger.Info("UDP holepunch routine timed out after 15 seconds for all exit nodes") - return - case <-ticker.C: - // Send hole punch to all exit nodes - for _, node := range resolvedNodes { - if err := sendUDPHolePunchWithConn(conn, node.remoteAddr, olmID, node.publicKey); err != nil { - logger.Error("Failed to send UDP hole punch to %s: %v", node.endpointName, err) - } - } - } - } -} - -func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16, serverPubKey string) { - - // Check if hole punching is already running - if holePunchRunning { - logger.Debug("UDP hole punch already running, skipping new request") - return - } - - // Set the flag to indicate hole punching is running - holePunchRunning = true - defer func() { - holePunchRunning = false - logger.Info("UDP hole punch goroutine ended") - }() - - logger.Info("Starting UDP hole punch to %s", 
endpoint) - defer logger.Info("UDP hole punch goroutine ended for %s", endpoint) - - host, err := resolveDomain(endpoint) - if err != nil { - logger.Error("Failed to resolve endpoint: %v", err) - return - } - - serverAddr := net.JoinHostPort(host, "21820") - - // Create the UDP connection once and reuse it - localAddr := &net.UDPAddr{ - Port: int(sourcePort), - IP: net.IPv4zero, - } - - remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) - if err != nil { - logger.Error("Failed to resolve UDP address: %v", err) - return - } - - conn, err := net.ListenUDP("udp", localAddr) - if err != nil { - logger.Error("Failed to bind UDP socket: %v", err) - return - } - defer conn.Close() - - // Execute once immediately before starting the loop - if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID, serverPubKey); err != nil { - logger.Error("Failed to send UDP hole punch: %v", err) - } - - ticker := time.NewTicker(250 * time.Millisecond) - defer ticker.Stop() - - timeout := time.NewTimer(15 * time.Second) - defer timeout.Stop() - - for { - select { - case <-stopHolepunch: - logger.Info("Stopping UDP holepunch") - return - case <-timeout.C: - logger.Info("UDP holepunch routine timed out after 15 seconds") - return - case <-ticker.C: - if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID, serverPubKey); err != nil { - logger.Error("Failed to send UDP hole punch: %v", err) - } - } - } -} - -func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { - if maxPort < minPort { - return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort) - } - - // Create a slice of all ports in the range - portRange := make([]uint16, maxPort-minPort+1) - for i := range portRange { - portRange[i] = minPort + uint16(i) - } - - // Fisher-Yates shuffle to randomize the port order - rand.Seed(uint64(time.Now().UnixNano())) - for i := len(portRange) - 1; i > 0; i-- { - j := rand.Intn(i + 1) - portRange[i], portRange[j] = portRange[j], portRange[i] - } - - // Try 
each port in the randomized order - for _, port := range portRange { - addr := &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: int(port), - } - conn, err := net.ListenUDP("udp", addr) - if err != nil { - continue // Port is in use or there was an error, try next port - } - _ = conn.SetDeadline(time.Now()) - conn.Close() - return port, nil - } - - return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort) -} - -func sendPing(olm *websocket.Client) error { - err := olm.SendMessage("olm/ping", map[string]interface{}{ - "timestamp": time.Now().Unix(), - }) - if err != nil { - logger.Error("Failed to send ping message: %v", err) - return err - } - logger.Debug("Sent ping message") - return nil -} - -func keepSendingPing(olm *websocket.Client) { - // Send ping immediately on startup - if err := sendPing(olm); err != nil { - logger.Error("Failed to send initial ping: %v", err) - } else { - logger.Info("Sent initial ping message") - } - - // Set up ticker for one minute intervals - ticker := time.NewTicker(1 * time.Minute) - defer ticker.Stop() - - for { - select { - case <-stopPing: - logger.Info("Stopping ping messages") - return - case <-ticker.C: - if err := sendPing(olm); err != nil { - logger.Error("Failed to send periodic ping: %v", err) - } - } - } -} - -// ConfigurePeer sets up or updates a peer within the WireGuard device -func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error { - siteHost, err := resolveDomain(siteConfig.Endpoint) - if err != nil { - return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err) - } - - // Split off the CIDR of the server IP which is just a string and add /32 for the allowed IP - allowedIp := strings.Split(siteConfig.ServerIP, "/") - if len(allowedIp) > 1 { - allowedIp[1] = "32" - } else { - allowedIp = append(allowedIp, "32") - } - allowedIpStr := strings.Join(allowedIp, "/") - - // Collect all allowed IPs in a slice 
- var allowedIPs []string - allowedIPs = append(allowedIPs, allowedIpStr) - - // If we have anything in remoteSubnets, add those as well - if siteConfig.RemoteSubnets != "" { - // Split remote subnets by comma and add each one - remoteSubnets := strings.Split(siteConfig.RemoteSubnets, ",") - for _, subnet := range remoteSubnets { - subnet = strings.TrimSpace(subnet) - if subnet != "" { - allowedIPs = append(allowedIPs, subnet) - } - } - } - - // Construct WireGuard config for this peer - var configBuilder strings.Builder - configBuilder.WriteString(fmt.Sprintf("private_key=%s\n", fixKey(privateKey.String()))) - configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", fixKey(siteConfig.PublicKey))) - - // Add each allowed IP separately - for _, allowedIP := range allowedIPs { - configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP)) - } - - configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost)) - configBuilder.WriteString("persistent_keepalive_interval=1\n") - - config := configBuilder.String() - logger.Debug("Configuring peer with config: %s", config) - - err = dev.IpcSet(config) - if err != nil { - return fmt.Errorf("failed to configure WireGuard peer: %v", err) - } - - // Set up peer monitoring - if peerMonitor != nil { - monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0] - monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port - logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer) - - primaryRelay, err := resolveDomain(endpoint) // Using global endpoint variable - if err != nil { - logger.Warn("Failed to resolve primary relay endpoint: %v", err) - } - - wgConfig := &peermonitor.WireGuardConfig{ - SiteID: siteConfig.SiteId, - PublicKey: fixKey(siteConfig.PublicKey), - ServerIP: strings.Split(siteConfig.ServerIP, "/")[0], - Endpoint: siteConfig.Endpoint, - PrimaryRelay: primaryRelay, - } - - err = 
peerMonitor.AddPeer(siteConfig.SiteId, monitorPeer, wgConfig) - if err != nil { - logger.Warn("Failed to setup monitoring for site %d: %v", siteConfig.SiteId, err) - } else { - logger.Info("Started monitoring for site %d at %s", siteConfig.SiteId, monitorPeer) - } - } - - return nil -} - -// RemovePeer removes a peer from the WireGuard device -func RemovePeer(dev *device.Device, siteId int, publicKey string) error { - // Construct WireGuard config to remove the peer - var configBuilder strings.Builder - configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", fixKey(publicKey))) - configBuilder.WriteString("remove=true\n") - - config := configBuilder.String() - logger.Debug("Removing peer with config: %s", config) - - err := dev.IpcSet(config) - if err != nil { - return fmt.Errorf("failed to remove WireGuard peer: %v", err) - } - - // Stop monitoring this peer - if peerMonitor != nil { - peerMonitor.RemovePeer(siteId) - logger.Info("Stopped monitoring for site %d", siteId) - } - - return nil -} - -// ConfigureInterface configures a network interface with an IP address and brings it up -func ConfigureInterface(interfaceName string, wgData WgData) error { - var ipAddr string = wgData.TunnelIP - - // Parse the IP address and network - ip, ipNet, err := net.ParseCIDR(ipAddr) - if err != nil { - return fmt.Errorf("invalid IP address: %v", err) - } - - switch runtime.GOOS { - case "linux": - return configureLinux(interfaceName, ip, ipNet) - case "darwin": - return configureDarwin(interfaceName, ip, ipNet) - case "windows": - return configureWindows(interfaceName, ip, ipNet) - default: - return fmt.Errorf("unsupported operating system: %s", runtime.GOOS) - } -} - -func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error { - logger.Info("Configuring Windows interface: %s", interfaceName) - - // Calculate mask string (e.g., 255.255.255.0) - maskBits, _ := ipNet.Mask.Size() - mask := net.CIDRMask(maskBits, 32) - maskIP := net.IP(mask) - - // Set the IP 
address using netsh - cmd := exec.Command("netsh", "interface", "ipv4", "set", "address", - fmt.Sprintf("name=%s", interfaceName), - "source=static", - fmt.Sprintf("addr=%s", ip.String()), - fmt.Sprintf("mask=%s", maskIP.String())) - - logger.Info("Running command: %v", cmd) - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("netsh command failed: %v, output: %s", err, out) - } - - // Bring up the interface if needed (in Windows, setting the IP usually brings it up) - // But we'll explicitly enable it to be sure - cmd = exec.Command("netsh", "interface", "set", "interface", - interfaceName, - "admin=enable") - - logger.Info("Running command: %v", cmd) - out, err = cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("netsh enable interface command failed: %v, output: %s", err, out) - } - - // delay 2 seconds - time.Sleep(8 * time.Second) - - // Wait for the interface to be up and have the correct IP - err = waitForInterfaceUp(interfaceName, ip, 30*time.Second) - if err != nil { - return fmt.Errorf("interface did not come up within timeout: %v", err) - } - - return nil -} - -// waitForInterfaceUp polls the network interface until it's up or times out -func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Duration) error { - logger.Info("Waiting for interface %s to be up with IP %s", interfaceName, expectedIP) - deadline := time.Now().Add(timeout) - pollInterval := 500 * time.Millisecond - - for time.Now().Before(deadline) { - // Check if interface exists and is up - iface, err := net.InterfaceByName(interfaceName) - if err == nil { - // Check if interface is up - if iface.Flags&net.FlagUp != 0 { - // Check if it has the expected IP - addrs, err := iface.Addrs() - if err == nil { - for _, addr := range addrs { - ipNet, ok := addr.(*net.IPNet) - if ok && ipNet.IP.Equal(expectedIP) { - logger.Info("Interface %s is up with correct IP", interfaceName) - return nil // Interface is up with correct IP - } - } - 
logger.Info("Interface %s is up but doesn't have expected IP yet", interfaceName) - } - } else { - logger.Info("Interface %s exists but is not up yet", interfaceName) - } - } else { - logger.Info("Interface %s not found yet: %v", interfaceName, err) - } - - // Wait before next check - time.Sleep(pollInterval) - } - - return fmt.Errorf("timed out waiting for interface %s to be up with IP %s", interfaceName, expectedIP) -} - -func WindowsAddRoute(destination string, gateway string, interfaceName string) error { - if runtime.GOOS != "windows" { - return nil - } - - var cmd *exec.Cmd - - // Parse destination to get the IP and subnet - ip, ipNet, err := net.ParseCIDR(destination) - if err != nil { - return fmt.Errorf("invalid destination address: %v", err) - } - - // Calculate the subnet mask - maskBits, _ := ipNet.Mask.Size() - mask := net.CIDRMask(maskBits, 32) - maskIP := net.IP(mask) - - if gateway != "" { - // Route with specific gateway - cmd = exec.Command("route", "add", - ip.String(), - "mask", maskIP.String(), - gateway, - "metric", "1") - } else if interfaceName != "" { - // First, get the interface index - indexCmd := exec.Command("netsh", "interface", "ipv4", "show", "interfaces") - output, err := indexCmd.CombinedOutput() - if err != nil { - return fmt.Errorf("failed to get interface index: %v, output: %s", err, output) - } - - // Parse the output to find the interface index - lines := strings.Split(string(output), "\n") - var ifIndex string - for _, line := range lines { - if strings.Contains(line, interfaceName) { - fields := strings.Fields(line) - if len(fields) > 0 { - ifIndex = fields[0] - break - } - } - } - - if ifIndex == "" { - return fmt.Errorf("could not find index for interface %s", interfaceName) - } - - // Convert to integer to validate - idx, err := strconv.Atoi(ifIndex) - if err != nil { - return fmt.Errorf("invalid interface index: %v", err) - } - - // Route via interface using the index - cmd = exec.Command("route", "add", - ip.String(), 
- "mask", maskIP.String(), - "0.0.0.0", - "if", strconv.Itoa(idx)) - } else { - return fmt.Errorf("either gateway or interface must be specified") - } - - logger.Info("Running command: %v", cmd) - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("route command failed: %v, output: %s", err, out) - } - - return nil -} - -func WindowsRemoveRoute(destination string) error { - // Parse destination to get the IP - ip, ipNet, err := net.ParseCIDR(destination) - if err != nil { - return fmt.Errorf("invalid destination address: %v", err) - } - - // Calculate the subnet mask - maskBits, _ := ipNet.Mask.Size() - mask := net.CIDRMask(maskBits, 32) - maskIP := net.IP(mask) - - cmd := exec.Command("route", "delete", - ip.String(), - "mask", maskIP.String()) - - logger.Info("Running command: %v", cmd) - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("route delete command failed: %v, output: %s", err, out) - } - - return nil -} - -func findUnusedUTUN() (string, error) { - ifaces, err := net.Interfaces() - if err != nil { - return "", fmt.Errorf("failed to list interfaces: %v", err) - } - used := make(map[int]bool) - re := regexp.MustCompile(`^utun(\d+)$`) - for _, iface := range ifaces { - if matches := re.FindStringSubmatch(iface.Name); len(matches) == 2 { - if num, err := strconv.Atoi(matches[1]); err == nil { - used[num] = true - } - } - } - // Try utun0 up to utun255. 
- for i := 0; i < 256; i++ { - if !used[i] { - return fmt.Sprintf("utun%d", i), nil - } - } - return "", fmt.Errorf("no unused utun interface found") -} - -func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error { - logger.Info("Configuring darwin interface: %s", interfaceName) - - prefix, _ := ipNet.Mask.Size() - ipStr := fmt.Sprintf("%s/%d", ip.String(), prefix) - - cmd := exec.Command("ifconfig", interfaceName, "inet", ipStr, ip.String(), "alias") - logger.Info("Running command: %v", cmd) - - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("ifconfig command failed: %v, output: %s", err, out) - } - - // Bring up the interface - cmd = exec.Command("ifconfig", interfaceName, "up") - logger.Info("Running command: %v", cmd) - - out, err = cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("ifconfig up command failed: %v, output: %s", err, out) - } - - return nil -} - -func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error { - // Get the interface - link, err := netlink.LinkByName(interfaceName) - if err != nil { - return fmt.Errorf("failed to get interface %s: %v", interfaceName, err) - } - - // Create the IP address attributes - addr := &netlink.Addr{ - IPNet: &net.IPNet{ - IP: ip, - Mask: ipNet.Mask, - }, - } - - // Add the IP address to the interface - if err := netlink.AddrAdd(link, addr); err != nil { - return fmt.Errorf("failed to add IP address: %v", err) - } - - // Bring up the interface - if err := netlink.LinkSetUp(link); err != nil { - return fmt.Errorf("failed to bring up interface: %v", err) - } - - return nil -} - -func DarwinAddRoute(destination string, gateway string, interfaceName string) error { - if runtime.GOOS != "darwin" { - return nil - } - - var cmd *exec.Cmd - - if gateway != "" { - // Route with specific gateway - cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-gateway", gateway) - } else if interfaceName != "" { - // Route via interface - cmd = 
exec.Command("route", "-q", "-n", "add", "-inet", destination, "-interface", interfaceName) - } else { - return fmt.Errorf("either gateway or interface must be specified") - } - - logger.Info("Running command: %v", cmd) - - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("route command failed: %v, output: %s", err, out) - } - - return nil -} - -func DarwinRemoveRoute(destination string) error { - if runtime.GOOS != "darwin" { - return nil - } - - cmd := exec.Command("route", "-q", "-n", "delete", "-inet", destination) - logger.Info("Running command: %v", cmd) - - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("route delete command failed: %v, output: %s", err, out) - } - - return nil -} - -func LinuxAddRoute(destination string, gateway string, interfaceName string) error { - if runtime.GOOS != "linux" { - return nil - } - - var cmd *exec.Cmd - - if gateway != "" { - // Route with specific gateway - cmd = exec.Command("ip", "route", "add", destination, "via", gateway) - } else if interfaceName != "" { - // Route via interface - cmd = exec.Command("ip", "route", "add", destination, "dev", interfaceName) - } else { - return fmt.Errorf("either gateway or interface must be specified") - } - - logger.Info("Running command: %v", cmd) - - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("ip route command failed: %v, output: %s", err, out) - } - - return nil -} - -func LinuxRemoveRoute(destination string) error { - if runtime.GOOS != "linux" { - return nil - } - - cmd := exec.Command("ip", "route", "del", destination) - logger.Info("Running command: %v", cmd) - - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("ip route delete command failed: %v, output: %s", err, out) - } - - return nil -} - -// addRouteForServerIP adds an OS-specific route for the server IP -func addRouteForServerIP(serverIP, interfaceName string) error { - if runtime.GOOS == "darwin" { - return 
DarwinAddRoute(serverIP, "", interfaceName) - } - // else if runtime.GOOS == "windows" { - // return WindowsAddRoute(serverIP, "", interfaceName) - // } else if runtime.GOOS == "linux" { - // return LinuxAddRoute(serverIP, "", interfaceName) - // } - return nil -} - -// removeRouteForServerIP removes an OS-specific route for the server IP -func removeRouteForServerIP(serverIP string) error { - if runtime.GOOS == "darwin" { - return DarwinRemoveRoute(serverIP) - } - // else if runtime.GOOS == "windows" { - // return WindowsRemoveRoute(serverIP) - // } else if runtime.GOOS == "linux" { - // return LinuxRemoveRoute(serverIP) - // } - return nil -} - -// addRoutesForRemoteSubnets adds routes for each comma-separated CIDR in RemoteSubnets -func addRoutesForRemoteSubnets(remoteSubnets, interfaceName string) error { - if remoteSubnets == "" { - return nil - } - - // Split remote subnets by comma and add routes for each one - subnets := strings.Split(remoteSubnets, ",") - for _, subnet := range subnets { - subnet = strings.TrimSpace(subnet) - if subnet == "" { - continue - } - - // Add route based on operating system - if runtime.GOOS == "darwin" { - if err := DarwinAddRoute(subnet, "", interfaceName); err != nil { - logger.Error("Failed to add Darwin route for subnet %s: %v", subnet, err) - return err - } - } else if runtime.GOOS == "windows" { - if err := WindowsAddRoute(subnet, "", interfaceName); err != nil { - logger.Error("Failed to add Windows route for subnet %s: %v", subnet, err) - return err - } - } else if runtime.GOOS == "linux" { - if err := LinuxAddRoute(subnet, "", interfaceName); err != nil { - logger.Error("Failed to add Linux route for subnet %s: %v", subnet, err) - return err - } - } - - logger.Info("Added route for remote subnet: %s", subnet) - } - return nil -} - -// removeRoutesForRemoteSubnets removes routes for each comma-separated CIDR in RemoteSubnets -func removeRoutesForRemoteSubnets(remoteSubnets string) error { - if remoteSubnets == "" { - 
return nil - } - - // Split remote subnets by comma and remove routes for each one - subnets := strings.Split(remoteSubnets, ",") - for _, subnet := range subnets { - subnet = strings.TrimSpace(subnet) - if subnet == "" { - continue - } - - // Remove route based on operating system - if runtime.GOOS == "darwin" { - if err := DarwinRemoveRoute(subnet); err != nil { - logger.Error("Failed to remove Darwin route for subnet %s: %v", subnet, err) - return err - } - } else if runtime.GOOS == "windows" { - if err := WindowsRemoveRoute(subnet); err != nil { - logger.Error("Failed to remove Windows route for subnet %s: %v", subnet, err) - return err - } - } else if runtime.GOOS == "linux" { - if err := LinuxRemoveRoute(subnet); err != nil { - logger.Error("Failed to remove Linux route for subnet %s: %v", subnet, err) - return err - } - } - - logger.Info("Removed route for remote subnet: %s", subnet) - } - return nil -} diff --git a/config.go b/config.go index 8b3664f..4b1c824 100644 --- a/config.go +++ b/config.go @@ -8,35 +8,43 @@ import ( "path/filepath" "runtime" "strconv" + "strings" "time" ) // OlmConfig holds all configuration options for the Olm client type OlmConfig struct { // Connection settings - Endpoint string `json:"endpoint"` - ID string `json:"id"` - Secret string `json:"secret"` + Endpoint string `json:"endpoint"` + ID string `json:"id"` + Secret string `json:"secret"` + OrgID string `json:"org"` + UserToken string `json:"userToken"` // Network settings - MTU int `json:"mtu"` - DNS string `json:"dns"` - InterfaceName string `json:"interface"` + MTU int `json:"mtu"` + DNS string `json:"dns"` + UpstreamDNS []string `json:"upstreamDNS"` + InterfaceName string `json:"interface"` // Logging LogLevel string `json:"logLevel"` // HTTP server - EnableHTTP bool `json:"enableHttp"` + EnableAPI bool `json:"enableApi"` HTTPAddr string `json:"httpAddr"` + SocketPath string `json:"socketPath"` // Ping settings PingInterval string `json:"pingInterval"` PingTimeout string 
`json:"pingTimeout"` // Advanced - Holepunch bool `json:"holepunch"` - TlsClientCert string `json:"tlsClientCert"` + DisableHolepunch bool `json:"disableHolepunch"` + TlsClientCert string `json:"tlsClientCert"` + OverrideDNS bool `json:"overrideDNS"` + DisableRelay bool `json:"disableRelay"` + // DoNotCreateNewClient bool `json:"doNotCreateNewClient"` // Parsed values (not in JSON) PingIntervalDuration time.Duration `json:"-"` @@ -44,6 +52,8 @@ type OlmConfig struct { // Source tracking (not in JSON) sources map[string]string `json:"-"` + + Version string } // ConfigSource tracks where each config value came from @@ -58,29 +68,45 @@ const ( // DefaultConfig returns a config with default values func DefaultConfig() *OlmConfig { + // Set OS-specific socket path + var socketPath string + switch runtime.GOOS { + case "windows": + socketPath = "olm" + default: // darwin, linux, and others + socketPath = "/var/run/olm.sock" + } + config := &OlmConfig{ - MTU: 1280, - DNS: "8.8.8.8", - LogLevel: "INFO", - InterfaceName: "olm", - EnableHTTP: false, - HTTPAddr: ":9452", - PingInterval: "3s", - PingTimeout: "5s", - Holepunch: false, - sources: make(map[string]string), + MTU: 1280, + DNS: "8.8.8.8", + UpstreamDNS: []string{"8.8.8.8:53"}, + LogLevel: "INFO", + InterfaceName: "olm", + EnableAPI: false, + SocketPath: socketPath, + PingInterval: "3s", + PingTimeout: "5s", + DisableHolepunch: false, + // DoNotCreateNewClient: false, + sources: make(map[string]string), } // Track default sources config.sources["mtu"] = string(SourceDefault) config.sources["dns"] = string(SourceDefault) + config.sources["upstreamDNS"] = string(SourceDefault) config.sources["logLevel"] = string(SourceDefault) config.sources["interface"] = string(SourceDefault) - config.sources["enableHttp"] = string(SourceDefault) + config.sources["enableApi"] = string(SourceDefault) config.sources["httpAddr"] = string(SourceDefault) + config.sources["socketPath"] = string(SourceDefault) config.sources["pingInterval"] 
= string(SourceDefault) config.sources["pingTimeout"] = string(SourceDefault) - config.sources["holepunch"] = string(SourceDefault) + config.sources["disableHolepunch"] = string(SourceDefault) + config.sources["overrideDNS"] = string(SourceDefault) + config.sources["disableRelay"] = string(SourceDefault) + // config.sources["doNotCreateNewClient"] = string(SourceDefault) return config } @@ -175,6 +201,14 @@ func loadConfigFromEnv(config *OlmConfig) { config.Secret = val config.sources["secret"] = string(SourceEnv) } + if val := os.Getenv("ORG"); val != "" { + config.OrgID = val + config.sources["org"] = string(SourceEnv) + } + if val := os.Getenv("USER_TOKEN"); val != "" { + config.UserToken = val + config.sources["userToken"] = string(SourceEnv) + } if val := os.Getenv("MTU"); val != "" { if mtu, err := strconv.Atoi(val); err == nil { config.MTU = mtu @@ -187,6 +221,10 @@ func loadConfigFromEnv(config *OlmConfig) { config.DNS = val config.sources["dns"] = string(SourceEnv) } + if val := os.Getenv("UPSTREAM_DNS"); val != "" { + config.UpstreamDNS = []string{val} + config.sources["upstreamDNS"] = string(SourceEnv) + } if val := os.Getenv("LOG_LEVEL"); val != "" { config.LogLevel = val config.sources["logLevel"] = string(SourceEnv) @@ -207,14 +245,30 @@ func loadConfigFromEnv(config *OlmConfig) { config.PingTimeout = val config.sources["pingTimeout"] = string(SourceEnv) } - if val := os.Getenv("ENABLE_HTTP"); val == "true" { - config.EnableHTTP = true - config.sources["enableHttp"] = string(SourceEnv) + if val := os.Getenv("ENABLE_API"); val == "true" { + config.EnableAPI = true + config.sources["enableApi"] = string(SourceEnv) } - if val := os.Getenv("HOLEPUNCH"); val == "true" { - config.Holepunch = true - config.sources["holepunch"] = string(SourceEnv) + if val := os.Getenv("SOCKET_PATH"); val != "" { + config.SocketPath = val + config.sources["socketPath"] = string(SourceEnv) } + if val := os.Getenv("DISABLE_HOLEPUNCH"); val == "true" { + config.DisableHolepunch 
= true + config.sources["disableHolepunch"] = string(SourceEnv) + } + if val := os.Getenv("OVERRIDE_DNS"); val == "true" { + config.OverrideDNS = true + config.sources["overrideDNS"] = string(SourceEnv) + } + if val := os.Getenv("DISABLE_RELAY"); val == "true" { + config.DisableRelay = true + config.sources["disableRelay"] = string(SourceEnv) + } + // if val := os.Getenv("DO_NOT_CREATE_NEW_CLIENT"); val == "true" { + // config.DoNotCreateNewClient = true + // config.sources["doNotCreateNewClient"] = string(SourceEnv) + // } } // loadConfigFromCLI loads configuration from command-line arguments @@ -223,33 +277,48 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { // Store original values to detect changes origValues := map[string]interface{}{ - "endpoint": config.Endpoint, - "id": config.ID, - "secret": config.Secret, - "mtu": config.MTU, - "dns": config.DNS, - "logLevel": config.LogLevel, - "interface": config.InterfaceName, - "httpAddr": config.HTTPAddr, - "pingInterval": config.PingInterval, - "pingTimeout": config.PingTimeout, - "enableHttp": config.EnableHTTP, - "holepunch": config.Holepunch, + "endpoint": config.Endpoint, + "id": config.ID, + "secret": config.Secret, + "org": config.OrgID, + "userToken": config.UserToken, + "mtu": config.MTU, + "dns": config.DNS, + "upstreamDNS": fmt.Sprintf("%v", config.UpstreamDNS), + "logLevel": config.LogLevel, + "interface": config.InterfaceName, + "httpAddr": config.HTTPAddr, + "socketPath": config.SocketPath, + "pingInterval": config.PingInterval, + "pingTimeout": config.PingTimeout, + "enableApi": config.EnableAPI, + "disableHolepunch": config.DisableHolepunch, + "overrideDNS": config.OverrideDNS, + "disableRelay": config.DisableRelay, + // "doNotCreateNewClient": config.DoNotCreateNewClient, } // Define flags serviceFlags.StringVar(&config.Endpoint, "endpoint", config.Endpoint, "Endpoint of your Pangolin server") serviceFlags.StringVar(&config.ID, "id", config.ID, "Olm ID") 
serviceFlags.StringVar(&config.Secret, "secret", config.Secret, "Olm secret") + serviceFlags.StringVar(&config.OrgID, "org", config.OrgID, "Organization ID") + serviceFlags.StringVar(&config.UserToken, "user-token", config.UserToken, "User token (optional)") serviceFlags.IntVar(&config.MTU, "mtu", config.MTU, "MTU to use") serviceFlags.StringVar(&config.DNS, "dns", config.DNS, "DNS server to use") + var upstreamDNSFlag string + serviceFlags.StringVar(&upstreamDNSFlag, "upstream-dns", "", "Upstream DNS server(s) (comma-separated, default: 8.8.8.8:53)") serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") serviceFlags.StringVar(&config.InterfaceName, "interface", config.InterfaceName, "Name of the WireGuard interface") serviceFlags.StringVar(&config.HTTPAddr, "http-addr", config.HTTPAddr, "HTTP server address (e.g., ':9452')") + serviceFlags.StringVar(&config.SocketPath, "socket-path", config.SocketPath, "Unix socket path (or named pipe on Windows)") serviceFlags.StringVar(&config.PingInterval, "ping-interval", config.PingInterval, "Interval for pinging the server") serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping") - serviceFlags.BoolVar(&config.EnableHTTP, "enable-http", config.EnableHTTP, "Enable HTTP server for receiving connection requests") - serviceFlags.BoolVar(&config.Holepunch, "holepunch", config.Holepunch, "Enable hole punching") + serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests") + serviceFlags.BoolVar(&config.DisableHolepunch, "disable-holepunch", config.DisableHolepunch, "Disable hole punching") + serviceFlags.BoolVar(&config.OverrideDNS, "override-dns", config.OverrideDNS, "Override system DNS settings") + serviceFlags.BoolVar(&config.DisableRelay, "disable-relay", config.DisableRelay, "Disable relay connections") + // 
serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client") version := serviceFlags.Bool("version", false, "Print the version") showConfig := serviceFlags.Bool("show-config", false, "Show configuration sources and exit") @@ -259,6 +328,16 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { return false, false, err } + // Parse upstream DNS flag if provided + if upstreamDNSFlag != "" { + config.UpstreamDNS = []string{} + for _, dns := range splitComma(upstreamDNSFlag) { + if dns != "" { + config.UpstreamDNS = append(config.UpstreamDNS, dns) + } + } + } + // Track which values were changed by CLI args if config.Endpoint != origValues["endpoint"].(string) { config.sources["endpoint"] = string(SourceCLI) @@ -269,12 +348,21 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { if config.Secret != origValues["secret"].(string) { config.sources["secret"] = string(SourceCLI) } + if config.OrgID != origValues["org"].(string) { + config.sources["org"] = string(SourceCLI) + } + if config.UserToken != origValues["userToken"].(string) { + config.sources["userToken"] = string(SourceCLI) + } if config.MTU != origValues["mtu"].(int) { config.sources["mtu"] = string(SourceCLI) } if config.DNS != origValues["dns"].(string) { config.sources["dns"] = string(SourceCLI) } + if fmt.Sprintf("%v", config.UpstreamDNS) != origValues["upstreamDNS"].(string) { + config.sources["upstreamDNS"] = string(SourceCLI) + } if config.LogLevel != origValues["logLevel"].(string) { config.sources["logLevel"] = string(SourceCLI) } @@ -284,18 +372,30 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { if config.HTTPAddr != origValues["httpAddr"].(string) { config.sources["httpAddr"] = string(SourceCLI) } + if config.SocketPath != origValues["socketPath"].(string) { + config.sources["socketPath"] = string(SourceCLI) + } if config.PingInterval != 
origValues["pingInterval"].(string) { config.sources["pingInterval"] = string(SourceCLI) } if config.PingTimeout != origValues["pingTimeout"].(string) { config.sources["pingTimeout"] = string(SourceCLI) } - if config.EnableHTTP != origValues["enableHttp"].(bool) { - config.sources["enableHttp"] = string(SourceCLI) + if config.EnableAPI != origValues["enableApi"].(bool) { + config.sources["enableApi"] = string(SourceCLI) } - if config.Holepunch != origValues["holepunch"].(bool) { - config.sources["holepunch"] = string(SourceCLI) + if config.DisableHolepunch != origValues["disableHolepunch"].(bool) { + config.sources["disableHolepunch"] = string(SourceCLI) } + if config.OverrideDNS != origValues["overrideDNS"].(bool) { + config.sources["overrideDNS"] = string(SourceCLI) + } + if config.DisableRelay != origValues["disableRelay"].(bool) { + config.sources["disableRelay"] = string(SourceCLI) + } + // if config.DoNotCreateNewClient != origValues["doNotCreateNewClient"].(bool) { + // config.sources["doNotCreateNewClient"] = string(SourceCLI) + // } return *version, *showConfig, nil } @@ -348,6 +448,14 @@ func mergeConfigs(dest, src *OlmConfig) { dest.Secret = src.Secret dest.sources["secret"] = string(SourceFile) } + if src.OrgID != "" { + dest.OrgID = src.OrgID + dest.sources["org"] = string(SourceFile) + } + if src.UserToken != "" { + dest.UserToken = src.UserToken + dest.sources["userToken"] = string(SourceFile) + } if src.MTU != 0 && src.MTU != 1280 { dest.MTU = src.MTU dest.sources["mtu"] = string(SourceFile) @@ -356,6 +464,10 @@ func mergeConfigs(dest, src *OlmConfig) { dest.DNS = src.DNS dest.sources["dns"] = string(SourceFile) } + if len(src.UpstreamDNS) > 0 && fmt.Sprintf("%v", src.UpstreamDNS) != "[8.8.8.8:53]" { + dest.UpstreamDNS = src.UpstreamDNS + dest.sources["upstreamDNS"] = string(SourceFile) + } if src.LogLevel != "" && src.LogLevel != "INFO" { dest.LogLevel = src.LogLevel dest.sources["logLevel"] = string(SourceFile) @@ -368,6 +480,14 @@ func 
mergeConfigs(dest, src *OlmConfig) { dest.HTTPAddr = src.HTTPAddr dest.sources["httpAddr"] = string(SourceFile) } + if src.SocketPath != "" { + // Check if it's not the default for any OS + isDefault := src.SocketPath == "/var/run/olm.sock" || src.SocketPath == "olm" + if !isDefault { + dest.SocketPath = src.SocketPath + dest.sources["socketPath"] = string(SourceFile) + } + } if src.PingInterval != "" && src.PingInterval != "3s" { dest.PingInterval = src.PingInterval dest.sources["pingInterval"] = string(SourceFile) @@ -381,14 +501,26 @@ func mergeConfigs(dest, src *OlmConfig) { dest.sources["tlsClientCert"] = string(SourceFile) } // For booleans, we always take the source value if explicitly set - if src.EnableHTTP { - dest.EnableHTTP = src.EnableHTTP - dest.sources["enableHttp"] = string(SourceFile) + if src.EnableAPI { + dest.EnableAPI = src.EnableAPI + dest.sources["enableApi"] = string(SourceFile) } - if src.Holepunch { - dest.Holepunch = src.Holepunch - dest.sources["holepunch"] = string(SourceFile) + if src.DisableHolepunch { + dest.DisableHolepunch = src.DisableHolepunch + dest.sources["disableHolepunch"] = string(SourceFile) } + if src.OverrideDNS { + dest.OverrideDNS = src.OverrideDNS + dest.sources["overrideDNS"] = string(SourceFile) + } + if src.DisableRelay { + dest.DisableRelay = src.DisableRelay + dest.sources["disableRelay"] = string(SourceFile) + } + // if src.DoNotCreateNewClient { + // dest.DoNotCreateNewClient = src.DoNotCreateNewClient + // dest.sources["doNotCreateNewClient"] = string(SourceFile) + // } } // SaveConfig saves the current configuration to the config file @@ -405,7 +537,7 @@ func SaveConfig(config *OlmConfig) error { func (c *OlmConfig) ShowConfig() { configPath := getOlmConfigPath() - fmt.Println("\n=== Olm Configuration ===\n") + fmt.Print("\n=== Olm Configuration ===\n\n") fmt.Printf("Config File: %s\n", configPath) // Check if config file exists @@ -416,7 +548,7 @@ func (c *OlmConfig) ShowConfig() { } fmt.Println("\n--- 
Configuration Values ---") - fmt.Println("(Format: Setting = Value [source])\n") + fmt.Print("(Format: Setting = Value [source])\n\n") // Helper to get source or default getSource := func(key string) string { @@ -445,21 +577,25 @@ func (c *OlmConfig) ShowConfig() { fmt.Printf(" endpoint = %s [%s]\n", formatValue("endpoint", c.Endpoint), getSource("endpoint")) fmt.Printf(" id = %s [%s]\n", formatValue("id", c.ID), getSource("id")) fmt.Printf(" secret = %s [%s]\n", formatValue("secret", c.Secret), getSource("secret")) + fmt.Printf(" org = %s [%s]\n", formatValue("org", c.OrgID), getSource("org")) + fmt.Printf(" user-token = %s [%s]\n", formatValue("userToken", c.UserToken), getSource("userToken")) // Network settings fmt.Println("\nNetwork:") fmt.Printf(" mtu = %d [%s]\n", c.MTU, getSource("mtu")) fmt.Printf(" dns = %s [%s]\n", c.DNS, getSource("dns")) + fmt.Printf(" upstream-dns = %v [%s]\n", c.UpstreamDNS, getSource("upstreamDNS")) fmt.Printf(" interface = %s [%s]\n", c.InterfaceName, getSource("interface")) // Logging fmt.Println("\nLogging:") fmt.Printf(" log-level = %s [%s]\n", c.LogLevel, getSource("logLevel")) - // HTTP server - fmt.Println("\nHTTP Server:") - fmt.Printf(" enable-http = %v [%s]\n", c.EnableHTTP, getSource("enableHttp")) + // API server + fmt.Println("\nAPI Server:") + fmt.Printf(" enable-api = %v [%s]\n", c.EnableAPI, getSource("enableApi")) fmt.Printf(" http-addr = %s [%s]\n", c.HTTPAddr, getSource("httpAddr")) + fmt.Printf(" socket-path = %s [%s]\n", c.SocketPath, getSource("socketPath")) // Timing fmt.Println("\nTiming:") @@ -468,9 +604,12 @@ func (c *OlmConfig) ShowConfig() { // Advanced fmt.Println("\nAdvanced:") - fmt.Printf(" holepunch = %v [%s]\n", c.Holepunch, getSource("holepunch")) + fmt.Printf(" disable-holepunch = %v [%s]\n", c.DisableHolepunch, getSource("disableHolepunch")) + fmt.Printf(" override-dns = %v [%s]\n", c.OverrideDNS, getSource("overrideDNS")) + fmt.Printf(" disable-relay = %v [%s]\n", c.DisableRelay, 
getSource("disableRelay")) + // fmt.Printf(" do-not-create-new-client = %v [%s]\n", c.DoNotCreateNewClient, getSource("doNotCreateNewClient")) if c.TlsClientCert != "" { - fmt.Printf(" tls-cert = %s [%s]\n", c.TlsClientCert, getSource("tlsClientCert")) + fmt.Printf(" tls-cert = %s [%s]\n", c.TlsClientCert, getSource("tlsClientCert")) } // Source legend @@ -482,3 +621,16 @@ func (c *OlmConfig) ShowConfig() { fmt.Println("\nPriority: cli > environment > file > default") fmt.Println() } + +// splitComma splits a comma-separated string into a slice of trimmed strings +func splitComma(s string) []string { + parts := strings.Split(s, ",") + result := make([]string, 0, len(parts)) + for _, part := range parts { + trimmed := strings.TrimSpace(part) + if trimmed != "" { + result = append(result, trimmed) + } + } + return result +} diff --git a/create_test_creds.py b/create_test_creds.py new file mode 100644 index 0000000..2a0eb1b --- /dev/null +++ b/create_test_creds.py @@ -0,0 +1,43 @@ + +import requests + +def create_olm(base_url, user_token, olm_name, user_id): + url = f"{base_url}/api/v1/user/{user_id}/olm" + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + "User-Agent": "pangolin-cli", + "X-CSRF-Token": "x-csrf-protection", + "Cookie": f"p_session_token={user_token}" + } + payload = {"name": olm_name} + response = requests.put(url, json=payload, headers=headers) + response.raise_for_status() + data = response.json() + print(f"Response Data: {data}") + +def create_client(base_url, user_token, client_name): + url = f"{base_url}/api/v1/api/clients" + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + "User-Agent": "pangolin-cli", + "X-CSRF-Token": "x-csrf-protection", + "Cookie": f"p_session_token={user_token}" + } + payload = {"name": client_name} + response = requests.post(url, json=payload, headers=headers) + response.raise_for_status() + data = response.json() + print(f"Response Data: {data}") + +if 
__name__ == "__main__": + # Example usage + base_url = input("Enter base URL (e.g., http://localhost:3000): ") + user_token = input("Enter user token: ") + user_id = input("Enter user ID: ") + olm_name = input("Enter OLM name: ") + client_name = input("Enter client name: ") + + create_olm(base_url, user_token, olm_name, user_id) + # client_id = create_client(base_url, user_token, client_name) \ No newline at end of file diff --git a/device/middle_device.go b/device/middle_device.go new file mode 100644 index 0000000..b031871 --- /dev/null +++ b/device/middle_device.go @@ -0,0 +1,331 @@ +package device + +import ( + "net/netip" + "os" + "sync" + + "github.com/fosrl/newt/logger" + "golang.zx2c4.com/wireguard/tun" +) + +// PacketHandler processes intercepted packets and returns true if packet should be dropped +type PacketHandler func(packet []byte) bool + +// FilterRule defines a rule for packet filtering +type FilterRule struct { + DestIP netip.Addr + Handler PacketHandler +} + +// MiddleDevice wraps a TUN device with packet filtering capabilities +type MiddleDevice struct { + tun.Device + rules []FilterRule + mutex sync.RWMutex + readCh chan readResult + injectCh chan []byte + closed chan struct{} +} + +type readResult struct { + bufs [][]byte + sizes []int + offset int + n int + err error +} + +// NewMiddleDevice creates a new filtered TUN device wrapper +func NewMiddleDevice(device tun.Device) *MiddleDevice { + d := &MiddleDevice{ + Device: device, + rules: make([]FilterRule, 0), + readCh: make(chan readResult), + injectCh: make(chan []byte, 100), + closed: make(chan struct{}), + } + go d.pump() + return d +} + +func (d *MiddleDevice) pump() { + const defaultOffset = 16 + batchSize := d.Device.BatchSize() + logger.Debug("MiddleDevice: pump started") + + for { + // Check closed first with priority + select { + case <-d.closed: + logger.Debug("MiddleDevice: pump exiting due to closed channel") + return + default: + } + + // Allocate buffers for reading + // We 
allocate new buffers for each read to avoid race conditions + // since we pass them to the channel + bufs := make([][]byte, batchSize) + sizes := make([]int, batchSize) + for i := range bufs { + bufs[i] = make([]byte, 2048) // Standard MTU + headroom + } + + n, err := d.Device.Read(bufs, sizes, defaultOffset) + + // Check closed again after read returns + select { + case <-d.closed: + logger.Debug("MiddleDevice: pump exiting due to closed channel (after read)") + return + default: + } + + // Now try to send the result + select { + case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}: + case <-d.closed: + logger.Debug("MiddleDevice: pump exiting due to closed channel (during send)") + return + } + + if err != nil { + logger.Debug("MiddleDevice: pump exiting due to read error: %v", err) + return + } + } +} + +// InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN) +func (d *MiddleDevice) InjectOutbound(packet []byte) { + select { + case d.injectCh <- packet: + case <-d.closed: + } +} + +// AddRule adds a packet filtering rule +func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) { + d.mutex.Lock() + defer d.mutex.Unlock() + d.rules = append(d.rules, FilterRule{ + DestIP: destIP, + Handler: handler, + }) +} + +// RemoveRule removes all rules for a given destination IP +func (d *MiddleDevice) RemoveRule(destIP netip.Addr) { + d.mutex.Lock() + defer d.mutex.Unlock() + newRules := make([]FilterRule, 0, len(d.rules)) + for _, rule := range d.rules { + if rule.DestIP != destIP { + newRules = append(newRules, rule) + } + } + d.rules = newRules +} + +// Close stops the device +func (d *MiddleDevice) Close() error { + select { + case <-d.closed: + // Already closed + return nil + default: + logger.Debug("MiddleDevice: Closing, signaling closed channel") + close(d.closed) + } + logger.Debug("MiddleDevice: Closing underlying TUN device") + err := d.Device.Close() + 
logger.Debug("MiddleDevice: Underlying TUN device closed, err=%v", err) + return err +} + +// extractDestIP extracts destination IP from packet (fast path) +func extractDestIP(packet []byte) (netip.Addr, bool) { + if len(packet) < 20 { + return netip.Addr{}, false + } + + version := packet[0] >> 4 + + switch version { + case 4: + if len(packet) < 20 { + return netip.Addr{}, false + } + // Destination IP is at bytes 16-19 for IPv4 + ip := netip.AddrFrom4([4]byte{packet[16], packet[17], packet[18], packet[19]}) + return ip, true + case 6: + if len(packet) < 40 { + return netip.Addr{}, false + } + // Destination IP is at bytes 24-39 for IPv6 + var ip16 [16]byte + copy(ip16[:], packet[24:40]) + ip := netip.AddrFrom16(ip16) + return ip, true + } + + return netip.Addr{}, false +} + +// Read intercepts packets going UP from the TUN device (towards WireGuard) +func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { + // Check if already closed first (non-blocking) + select { + case <-d.closed: + logger.Debug("MiddleDevice: Read returning os.ErrClosed (pre-check)") + return 0, os.ErrClosed + default: + } + + // Now block waiting for data + select { + case res := <-d.readCh: + if res.err != nil { + logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err) + return 0, res.err + } + + // Copy packets from result to provided buffers + count := 0 + for i := 0; i < res.n && i < len(bufs); i++ { + // Handle offset mismatch if necessary + // We assume the pump used defaultOffset (16) + // If caller asks for different offset, we need to shift + src := res.bufs[i] + srcOffset := res.offset + srcSize := res.sizes[i] + + // Calculate where the packet data starts and ends in src + pktData := src[srcOffset : srcOffset+srcSize] + + // Ensure dest buffer is large enough + if len(bufs[i]) < offset+len(pktData) { + continue // Skip if buffer too small + } + + copy(bufs[i][offset:], pktData) + sizes[i] = len(pktData) + count++ + } + n = count + 
+ case pkt := <-d.injectCh: + if len(bufs) == 0 { + return 0, nil + } + if len(bufs[0]) < offset+len(pkt) { + return 0, nil // Buffer too small + } + copy(bufs[0][offset:], pkt) + sizes[0] = len(pkt) + n = 1 + + case <-d.closed: + logger.Debug("MiddleDevice: Read returning os.ErrClosed") + return 0, os.ErrClosed // Signal that device is closed + } + + d.mutex.RLock() + rules := d.rules + d.mutex.RUnlock() + + if len(rules) == 0 { + return n, nil + } + + // Process packets and filter out handled ones + writeIdx := 0 + for readIdx := 0; readIdx < n; readIdx++ { + packet := bufs[readIdx][offset : offset+sizes[readIdx]] + + destIP, ok := extractDestIP(packet) + if !ok { + // Can't parse, keep packet + if writeIdx != readIdx { + bufs[writeIdx] = bufs[readIdx] + sizes[writeIdx] = sizes[readIdx] + } + writeIdx++ + continue + } + + // Check if packet matches any rule + handled := false + for _, rule := range rules { + if rule.DestIP == destIP { + if rule.Handler(packet) { + // Packet was handled and should be dropped + handled = true + break + } + } + } + + if !handled { + // Keep packet + if writeIdx != readIdx { + bufs[writeIdx] = bufs[readIdx] + sizes[writeIdx] = sizes[readIdx] + } + writeIdx++ + } + } + + return writeIdx, err +} + +// Write intercepts packets going DOWN to the TUN device (from WireGuard) +func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) { + d.mutex.RLock() + rules := d.rules + d.mutex.RUnlock() + + if len(rules) == 0 { + return d.Device.Write(bufs, offset) + } + + // Filter packets going down + filteredBufs := make([][]byte, 0, len(bufs)) + for _, buf := range bufs { + if len(buf) <= offset { + continue + } + + packet := buf[offset:] + destIP, ok := extractDestIP(packet) + if !ok { + // Can't parse, keep packet + filteredBufs = append(filteredBufs, buf) + continue + } + + // Check if packet matches any rule + handled := false + for _, rule := range rules { + if rule.DestIP == destIP { + if rule.Handler(packet) { + // Packet was 
handled and should be dropped + handled = true + break + } + } + } + + if !handled { + filteredBufs = append(filteredBufs, buf) + } + } + + if len(filteredBufs) == 0 { + return len(bufs), nil // All packets were handled + } + + return d.Device.Write(filteredBufs, offset) +} diff --git a/device/middle_device_test.go b/device/middle_device_test.go new file mode 100644 index 0000000..58cb88f --- /dev/null +++ b/device/middle_device_test.go @@ -0,0 +1,102 @@ +package device + +import ( + "net/netip" + "testing" + + "github.com/fosrl/newt/util" +) + +func TestExtractDestIP(t *testing.T) { + tests := []struct { + name string + packet []byte + wantIP string + wantOk bool + }{ + { + name: "IPv4 packet", + packet: []byte{ + 0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00, + 0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01, + 0x0a, 0x1e, 0x1e, 0x1e, // Dest IP: 10.30.30.30 + }, + wantIP: "10.30.30.30", + wantOk: true, + }, + { + name: "Too short packet", + packet: []byte{0x45, 0x00}, + wantIP: "", + wantOk: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotIP, gotOk := extractDestIP(tt.packet) + if gotOk != tt.wantOk { + t.Errorf("extractDestIP() ok = %v, want %v", gotOk, tt.wantOk) + return + } + if tt.wantOk { + wantAddr := netip.MustParseAddr(tt.wantIP) + if gotIP != wantAddr { + t.Errorf("extractDestIP() ip = %v, want %v", gotIP, wantAddr) + } + } + }) + } +} + +func TestGetProtocol(t *testing.T) { + tests := []struct { + name string + packet []byte + wantProto uint8 + wantOk bool + }{ + { + name: "UDP packet", + packet: []byte{ + 0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00, + 0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01, // Protocol: UDP (17) at byte 9 + 0x0a, 0x1e, 0x1e, 0x1e, + }, + wantProto: 17, + wantOk: true, + }, + { + name: "Too short", + packet: []byte{0x45, 0x00}, + wantProto: 0, + wantOk: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotProto, gotOk := 
util.GetProtocol(tt.packet) + if gotOk != tt.wantOk { + t.Errorf("GetProtocol() ok = %v, want %v", gotOk, tt.wantOk) + return + } + if gotProto != tt.wantProto { + t.Errorf("GetProtocol() proto = %v, want %v", gotProto, tt.wantProto) + } + }) + } +} + +func BenchmarkExtractDestIP(b *testing.B) { + packet := []byte{ + 0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00, + 0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01, + 0x0a, 0x1e, 0x1e, 0x1e, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + extractDestIP(packet) + } +} diff --git a/device/tun_unix.go b/device/tun_unix.go new file mode 100644 index 0000000..c9bab60 --- /dev/null +++ b/device/tun_unix.go @@ -0,0 +1,44 @@ +//go:build !windows + +package device + +import ( + "net" + "os" + + "github.com/fosrl/newt/logger" + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/ipc" + "golang.zx2c4.com/wireguard/tun" +) + +func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) { + dupTunFd, err := unix.Dup(int(tunFd)) + if err != nil { + logger.Error("Unable to dup tun fd: %v", err) + return nil, err + } + + err = unix.SetNonblock(dupTunFd, true) + if err != nil { + unix.Close(dupTunFd) + return nil, err + } + + file := os.NewFile(uintptr(dupTunFd), "/dev/tun") + device, err := tun.CreateTUNFromFile(file, mtuInt) + if err != nil { + file.Close() + return nil, err + } + + return device, nil +} + +func UapiOpen(interfaceName string) (*os.File, error) { + return ipc.UAPIOpen(interfaceName) +} + +func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) { + return ipc.UAPIListen(interfaceName, fileUAPI) +} diff --git a/windows.go b/device/tun_windows.go similarity index 61% rename from windows.go rename to device/tun_windows.go index 032096b..edcd6f6 100644 --- a/windows.go +++ b/device/tun_windows.go @@ -1,6 +1,6 @@ //go:build windows -package main +package device import ( "errors" @@ -11,15 +11,15 @@ import ( "golang.zx2c4.com/wireguard/tun" ) -func createTUNFromFD(tunFdStr string, mtuInt 
int) (tun.Device, error) { +func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) { return nil, errors.New("CreateTUNFromFile not supported on Windows") } -func uapiOpen(interfaceName string) (*os.File, error) { +func UapiOpen(interfaceName string) (*os.File, error) { return nil, nil } -func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) { +func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) { // On Windows, UAPIListen only takes one parameter return ipc.UAPIListen(interfaceName) } diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go new file mode 100644 index 0000000..d0ed7b3 --- /dev/null +++ b/dns/dns_proxy.go @@ -0,0 +1,457 @@ +package dns + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + "time" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/util" + "github.com/fosrl/olm/device" + "github.com/miekg/dns" + "golang.zx2c4.com/wireguard/tun" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" +) + +const ( + DNSPort = 53 +) + +// DNSProxy implements a DNS proxy using gvisor netstack +type DNSProxy struct { + stack *stack.Stack + ep *channel.Endpoint + proxyIP netip.Addr + upstreamDNS []string + mtu int + tunDevice tun.Device // Direct reference to underlying TUN device for responses + middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering + recordStore *DNSRecordStore // Local DNS records + + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup +} + +// NewDNSProxy creates a new DNS proxy +func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string) 
(*DNSProxy, error) { + proxyIP, err := PickIPFromSubnet(utilitySubnet) + if err != nil { + return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err) + } + + if len(upstreamDns) == 0 { + return nil, fmt.Errorf("at least one upstream DNS server must be specified") + } + + ctx, cancel := context.WithCancel(context.Background()) + + proxy := &DNSProxy{ + proxyIP: proxyIP, + mtu: mtu, + tunDevice: tunDevice, + middleDevice: middleDevice, + upstreamDNS: upstreamDns, + recordStore: NewDNSRecordStore(), + ctx: ctx, + cancel: cancel, + } + + // Create gvisor netstack + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + HandleLocal: true, + } + + proxy.ep = channel.New(256, uint32(mtu), "") + proxy.stack = stack.New(stackOpts) + + // Create NIC + if err := proxy.stack.CreateNIC(1, proxy.ep); err != nil { + return nil, fmt.Errorf("failed to create NIC: %v", err) + } + + // Add IP address + // Parse the proxy IP to get the octets + ipBytes := proxyIP.As4() + protoAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(), + } + + if err := proxy.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil { + return nil, fmt.Errorf("failed to add protocol address: %v", err) + } + + // Add default route + proxy.stack.AddRoute(tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: 1, + }) + + return proxy, nil +} + +// Start starts the DNS proxy and registers with the filter +func (p *DNSProxy) Start() error { + // Install packet filter rule + p.middleDevice.AddRule(p.proxyIP, p.handlePacket) + + // Start DNS listener + p.wg.Add(2) + go p.runDNSListener() + go p.runPacketSender() + + logger.Info("DNS proxy started on %s:%d", p.proxyIP.String(), DNSPort) + return nil +} + +// Stop stops the DNS proxy +func (p *DNSProxy) Stop() { + if 
p.middleDevice != nil { + p.middleDevice.RemoveRule(p.proxyIP) + } + p.cancel() + + // Close the endpoint first to unblock any pending Read() calls in runPacketSender + if p.ep != nil { + p.ep.Close() + } + + p.wg.Wait() + + if p.stack != nil { + p.stack.Close() + } + + logger.Info("DNS proxy stopped") +} + +func (p *DNSProxy) GetProxyIP() netip.Addr { + return p.proxyIP +} + +// handlePacket is called by the filter for packets destined to DNS proxy IP +func (p *DNSProxy) handlePacket(packet []byte) bool { + if len(packet) < 20 { + return false // Don't drop, malformed + } + + // Quick check for UDP port 53 + proto, ok := util.GetProtocol(packet) + if !ok || proto != 17 { // 17 = UDP + return false // Not UDP, don't handle + } + + port, ok := util.GetDestPort(packet) + if !ok || port != DNSPort { + return false // Not DNS port + } + + // Inject packet into our netstack + version := packet[0] >> 4 + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(packet), + }) + + switch version { + case 4: + p.ep.InjectInbound(ipv4.ProtocolNumber, pkb) + case 6: + p.ep.InjectInbound(ipv6.ProtocolNumber, pkb) + default: + pkb.DecRef() + return false + } + + pkb.DecRef() + return true // Drop packet from normal path +} + +// runDNSListener listens for DNS queries on the netstack +func (p *DNSProxy) runDNSListener() { + defer p.wg.Done() + + // Create UDP listener using gonet + // Parse the proxy IP to get the octets + ipBytes := p.proxyIP.As4() + laddr := &tcpip.FullAddress{ + NIC: 1, + Addr: tcpip.AddrFrom4(ipBytes), + Port: DNSPort, + } + + udpConn, err := gonet.DialUDP(p.stack, laddr, nil, ipv4.ProtocolNumber) + if err != nil { + logger.Error("Failed to create DNS listener: %v", err) + return + } + defer udpConn.Close() + + logger.Debug("DNS proxy listening on netstack") + + // Handle DNS queries + buf := make([]byte, 4096) + for { + select { + case <-p.ctx.Done(): + return + default: + } + + udpConn.SetReadDeadline(time.Now().Add(1 * 
time.Second)) + n, remoteAddr, err := udpConn.ReadFrom(buf) + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + continue + } + if p.ctx.Err() != nil { + return + } + logger.Error("DNS read error: %v", err) + continue + } + + query := make([]byte, n) + copy(query, buf[:n]) + + // Handle query in background + go p.handleDNSQuery(udpConn, query, remoteAddr) + } +} + +// handleDNSQuery processes a DNS query, checking local records first, then forwarding upstream +func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clientAddr net.Addr) { + // Parse the DNS query + msg := new(dns.Msg) + if err := msg.Unpack(queryData); err != nil { + logger.Error("Failed to parse DNS query: %v", err) + return + } + + if len(msg.Question) == 0 { + logger.Debug("DNS query has no questions") + return + } + + question := msg.Question[0] + logger.Debug("DNS query for %s (type %s)", question.Name, dns.TypeToString[question.Qtype]) + + // Check if we have local records for this query + var response *dns.Msg + if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA { + response = p.checkLocalRecords(msg, question) + } + + // If no local records, forward to upstream + if response == nil { + logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS) + response = p.forwardToUpstream(msg) + } + + if response == nil { + logger.Error("Failed to get DNS response for %s", question.Name) + return + } + + // Pack and send response + responseData, err := response.Pack() + if err != nil { + logger.Error("Failed to pack DNS response: %v", err) + return + } + + _, err = udpConn.WriteTo(responseData, clientAddr) + if err != nil { + logger.Error("Failed to send DNS response: %v", err) + } +} + +// checkLocalRecords checks if we have local records for the query +func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns.Msg { + var recordType RecordType + if question.Qtype == dns.TypeA { + 
recordType = RecordTypeA + } else if question.Qtype == dns.TypeAAAA { + recordType = RecordTypeAAAA + } else { + return nil + } + + ips := p.recordStore.GetRecords(question.Name, recordType) + if len(ips) == 0 { + return nil + } + + logger.Debug("Found %d local record(s) for %s", len(ips), question.Name) + + // Create response message + response := new(dns.Msg) + response.SetReply(query) + response.Authoritative = true + + // Add answer records + for _, ip := range ips { + var rr dns.RR + if question.Qtype == dns.TypeA { + rr = &dns.A{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, // 5 minutes + }, + A: ip.To4(), + } + } else { // TypeAAAA + rr = &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 300, // 5 minutes + }, + AAAA: ip.To16(), + } + } + response.Answer = append(response.Answer, rr) + } + + return response +} + +// forwardToUpstream forwards a DNS query to upstream DNS servers +func (p *DNSProxy) forwardToUpstream(query *dns.Msg) *dns.Msg { + // Try primary DNS server + response, err := p.queryUpstream(p.upstreamDNS[0], query, 2*time.Second) + if err != nil && len(p.upstreamDNS) > 1 { + // Try secondary DNS server + logger.Debug("Primary DNS failed, trying secondary: %v", err) + response, err = p.queryUpstream(p.upstreamDNS[1], query, 2*time.Second) + if err != nil { + logger.Error("Both DNS servers failed: %v", err) + return nil + } + } + return response +} + +// queryUpstream sends a DNS query to upstream server using miekg/dns +func (p *DNSProxy) queryUpstream(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) { + client := &dns.Client{ + Timeout: timeout, + } + + response, _, err := client.Exchange(query, server) + if err != nil { + return nil, err + } + + return response, nil +} + +// runPacketSender sends packets from netstack back to TUN +func (p *DNSProxy) runPacketSender() { + defer p.wg.Done() + + // 
MessageTransportHeaderSize is the offset used by WireGuard device + // for reading/writing packets to the TUN interface + const offset = 16 + + for { + select { + case <-p.ctx.Done(): + return + default: + } + + // Read packets from netstack endpoint + pkt := p.ep.Read() + if pkt == nil { + // No packet available, small sleep to avoid busy loop + time.Sleep(1 * time.Millisecond) + continue + } + + // Extract packet data as slices + slices := pkt.AsSlices() + if len(slices) > 0 { + // Flatten all slices into a single packet buffer + var totalSize int + for _, slice := range slices { + totalSize += len(slice) + } + + // Allocate buffer with offset space for WireGuard transport header + // The first 'offset' bytes are reserved for the transport header + buf := make([]byte, offset+totalSize) + + // Copy packet data after the offset + pos := offset + for _, slice := range slices { + copy(buf[pos:], slice) + pos += len(slice) + } + + // Write packet to TUN device + // offset=16 indicates packet data starts at position 16 in the buffer + _, err := p.tunDevice.Write([][]byte{buf}, offset) + if err != nil { + logger.Error("Failed to write DNS response to TUN: %v", err) + } + } + + pkt.DecRef() + } +} + +// AddDNSRecord adds a DNS record to the local store +// domain should be a domain name (e.g., "example.com" or "example.com.") +// ip should be a valid IPv4 or IPv6 address +func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP) error { + return p.recordStore.AddRecord(domain, ip) +} + +// RemoveDNSRecord removes a DNS record from the local store +// If ip is nil, removes all records for the domain +func (p *DNSProxy) RemoveDNSRecord(domain string, ip net.IP) { + p.recordStore.RemoveRecord(domain, ip) +} + +// GetDNSRecords returns all IP addresses for a domain and record type +func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) []net.IP { + return p.recordStore.GetRecords(domain, recordType) +} + +// ClearDNSRecords removes all DNS records from the 
local store +func (p *DNSProxy) ClearDNSRecords() { + p.recordStore.Clear() +} + +func PickIPFromSubnet(subnet string) (netip.Addr, error) { + // given a subnet in CIDR notation, pick the first usable IP + prefix, err := netip.ParsePrefix(subnet) + if err != nil { + return netip.Addr{}, fmt.Errorf("invalid subnet: %w", err) + } + + // Pick the first usable IP address from the subnet + ip := prefix.Addr().Next() + if !ip.IsValid() { + return netip.Addr{}, fmt.Errorf("no valid IP address found in subnet: %s", subnet) + } + + return ip, nil +} diff --git a/dns/dns_records.go b/dns/dns_records.go new file mode 100644 index 0000000..8d57d68 --- /dev/null +++ b/dns/dns_records.go @@ -0,0 +1,166 @@ +package dns + +import ( + "net" + "sync" + + "github.com/miekg/dns" +) + +// RecordType represents the type of DNS record +type RecordType uint16 + +const ( + RecordTypeA RecordType = RecordType(dns.TypeA) + RecordTypeAAAA RecordType = RecordType(dns.TypeAAAA) +) + +// DNSRecordStore manages local DNS records for A and AAAA queries +type DNSRecordStore struct { + mu sync.RWMutex + aRecords map[string][]net.IP // domain -> list of IPv4 addresses + aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses +} + +// NewDNSRecordStore creates a new DNS record store +func NewDNSRecordStore() *DNSRecordStore { + return &DNSRecordStore{ + aRecords: make(map[string][]net.IP), + aaaaRecords: make(map[string][]net.IP), + } +} + +// AddRecord adds a DNS record mapping (A or AAAA) +// domain should be in FQDN format (e.g., "example.com.") +// ip should be a valid IPv4 or IPv6 address +func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Ensure domain ends with a dot (FQDN format) + if len(domain) == 0 || domain[len(domain)-1] != '.' { + domain = domain + "." 
+ } + + // Normalize domain to lowercase + domain = dns.Fqdn(domain) + + if ip.To4() != nil { + // IPv4 address + s.aRecords[domain] = append(s.aRecords[domain], ip) + } else if ip.To16() != nil { + // IPv6 address + s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip) + } else { + return &net.ParseError{Type: "IP address", Text: ip.String()} + } + + return nil +} + +// RemoveRecord removes a specific DNS record mapping +// If ip is nil, removes all records for the domain +func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) { + s.mu.Lock() + defer s.mu.Unlock() + + // Ensure domain ends with a dot (FQDN format) + if len(domain) == 0 || domain[len(domain)-1] != '.' { + domain = domain + "." + } + + // Normalize domain to lowercase + domain = dns.Fqdn(domain) + + if ip == nil { + // Remove all records for this domain + delete(s.aRecords, domain) + delete(s.aaaaRecords, domain) + return + } + + if ip.To4() != nil { + // Remove specific IPv4 address + if ips, ok := s.aRecords[domain]; ok { + s.aRecords[domain] = removeIP(ips, ip) + if len(s.aRecords[domain]) == 0 { + delete(s.aRecords, domain) + } + } + } else if ip.To16() != nil { + // Remove specific IPv6 address + if ips, ok := s.aaaaRecords[domain]; ok { + s.aaaaRecords[domain] = removeIP(ips, ip) + if len(s.aaaaRecords[domain]) == 0 { + delete(s.aaaaRecords, domain) + } + } + } +} + +// GetRecords returns all IP addresses for a domain and record type +func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP { + s.mu.RLock() + defer s.mu.RUnlock() + + // Normalize domain to lowercase FQDN + domain = dns.Fqdn(domain) + + var records []net.IP + switch recordType { + case RecordTypeA: + if ips, ok := s.aRecords[domain]; ok { + // Return a copy to prevent external modifications + records = make([]net.IP, len(ips)) + copy(records, ips) + } + case RecordTypeAAAA: + if ips, ok := s.aaaaRecords[domain]; ok { + // Return a copy to prevent external modifications + records = 
make([]net.IP, len(ips)) + copy(records, ips) + } + } + + return records +} + +// HasRecord checks if a domain has any records of the specified type +func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool { + s.mu.RLock() + defer s.mu.RUnlock() + + // Normalize domain to lowercase FQDN + domain = dns.Fqdn(domain) + + switch recordType { + case RecordTypeA: + _, ok := s.aRecords[domain] + return ok + case RecordTypeAAAA: + _, ok := s.aaaaRecords[domain] + return ok + } + + return false +} + +// Clear removes all records from the store +func (s *DNSRecordStore) Clear() { + s.mu.Lock() + defer s.mu.Unlock() + + s.aRecords = make(map[string][]net.IP) + s.aaaaRecords = make(map[string][]net.IP) +} + +// removeIP is a helper function to remove a specific IP from a slice +func removeIP(ips []net.IP, toRemove net.IP) []net.IP { + result := make([]net.IP, 0, len(ips)) + for _, ip := range ips { + if !ip.Equal(toRemove) { + result = append(result, ip) + } + } + return result +} diff --git a/dns/override/dns_override_darwin.go b/dns/override/dns_override_darwin.go new file mode 100644 index 0000000..6ccc3fb --- /dev/null +++ b/dns/override/dns_override_darwin.go @@ -0,0 +1,68 @@ +//go:build darwin && !ios + +package olm + +import ( + "fmt" + "net/netip" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/olm/dns" + platform "github.com/fosrl/olm/dns/platform" +) + +var configurator platform.DNSConfigurator + +// SetupDNSOverride configures the system DNS to use the DNS proxy on macOS +// Uses scutil for DNS configuration +func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { + if dnsProxy == nil { + return fmt.Errorf("DNS proxy is nil") + } + + var err error + configurator, err = platform.NewDarwinDNSConfigurator() + if err != nil { + return fmt.Errorf("failed to create Darwin DNS configurator: %w", err) + } + + logger.Info("Using Darwin scutil DNS configurator") + + // Get current DNS servers before changing + currentDNS, err 
:= configurator.GetCurrentDNS() + if err != nil { + logger.Warn("Could not get current DNS: %v", err) + } else { + logger.Info("Current DNS servers: %v", currentDNS) + } + + // Set new DNS servers to point to our proxy + newDNS := []netip.Addr{ + dnsProxy.GetProxyIP(), + } + + logger.Info("Setting DNS servers to: %v", newDNS) + originalDNS, err := configurator.SetDNS(newDNS) + if err != nil { + return fmt.Errorf("failed to set DNS: %w", err) + } + + logger.Info("Original DNS servers backed up: %v", originalDNS) + return nil +} + +// RestoreDNSOverride restores the original DNS configuration +func RestoreDNSOverride() error { + if configurator == nil { + logger.Debug("No DNS configurator to restore") + return nil + } + + logger.Info("Restoring original DNS configuration") + if err := configurator.RestoreDNS(); err != nil { + return fmt.Errorf("failed to restore DNS: %w", err) + } + + logger.Info("DNS configuration restored successfully") + return nil +} diff --git a/dns/override/dns_override_unix.go b/dns/override/dns_override_unix.go new file mode 100644 index 0000000..c3b31e8 --- /dev/null +++ b/dns/override/dns_override_unix.go @@ -0,0 +1,105 @@ +//go:build (linux && !android) || freebsd + +package olm + +import ( + "fmt" + "net/netip" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/olm/dns" + platform "github.com/fosrl/olm/dns/platform" +) + +var configurator platform.DNSConfigurator + +// SetupDNSOverride configures the system DNS to use the DNS proxy on Linux/FreeBSD +// Detects the DNS manager by reading /etc/resolv.conf and verifying runtime availability +func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { + if dnsProxy == nil { + return fmt.Errorf("DNS proxy is nil") + } + + var err error + + // Detect which DNS manager is in use by checking /etc/resolv.conf and runtime availability + managerType := platform.DetectDNSManager(interfaceName) + logger.Info("Detected DNS manager: %s", managerType.String()) + + // Create 
configurator based on detected manager + switch managerType { + case platform.SystemdResolvedManager: + configurator, err = platform.NewSystemdResolvedDNSConfigurator(interfaceName) + if err == nil { + logger.Info("Using systemd-resolved DNS configurator") + return setDNS(dnsProxy, configurator) + } + logger.Warn("Failed to create systemd-resolved configurator: %v, falling back", err) + + case platform.NetworkManagerManager: + configurator, err = platform.NewNetworkManagerDNSConfigurator(interfaceName) + if err == nil { + logger.Info("Using NetworkManager DNS configurator") + return setDNS(dnsProxy, configurator) + } + logger.Warn("Failed to create NetworkManager configurator: %v, falling back", err) + + case platform.ResolvconfManager: + configurator, err = platform.NewResolvconfDNSConfigurator(interfaceName) + if err == nil { + logger.Info("Using resolvconf DNS configurator") + return setDNS(dnsProxy, configurator) + } + logger.Warn("Failed to create resolvconf configurator: %v, falling back", err) + } + + // Fall back to direct file manipulation + configurator, err = platform.NewFileDNSConfigurator() + if err != nil { + return fmt.Errorf("failed to create file DNS configurator: %w", err) + } + + logger.Info("Using file-based DNS configurator") + return setDNS(dnsProxy, configurator) +} + +// setDNS is a helper function to set DNS and log the results +func setDNS(dnsProxy *dns.DNSProxy, conf platform.DNSConfigurator) error { + // Get current DNS servers before changing + currentDNS, err := conf.GetCurrentDNS() + if err != nil { + logger.Warn("Could not get current DNS: %v", err) + } else { + logger.Info("Current DNS servers: %v", currentDNS) + } + + // Set new DNS servers to point to our proxy + newDNS := []netip.Addr{ + dnsProxy.GetProxyIP(), + } + + logger.Info("Setting DNS servers to: %v", newDNS) + originalDNS, err := conf.SetDNS(newDNS) + if err != nil { + return fmt.Errorf("failed to set DNS: %w", err) + } + + logger.Info("Original DNS servers backed up: 
%v", originalDNS) + return nil +} + +// RestoreDNSOverride restores the original DNS configuration +func RestoreDNSOverride() error { + if configurator == nil { + logger.Debug("No DNS configurator to restore") + return nil + } + + logger.Info("Restoring original DNS configuration") + if err := configurator.RestoreDNS(); err != nil { + return fmt.Errorf("failed to restore DNS: %w", err) + } + + logger.Info("DNS configuration restored successfully") + return nil +} diff --git a/dns/override/dns_override_windows.go b/dns/override/dns_override_windows.go new file mode 100644 index 0000000..a564079 --- /dev/null +++ b/dns/override/dns_override_windows.go @@ -0,0 +1,68 @@ +//go:build windows + +package olm + +import ( + "fmt" + "net/netip" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/olm/dns" + platform "github.com/fosrl/olm/dns/platform" +) + +var configurator platform.DNSConfigurator + +// SetupDNSOverride configures the system DNS to use the DNS proxy on Windows +// Uses registry-based configuration (automatically extracts interface GUID) +func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { + if dnsProxy == nil { + return fmt.Errorf("DNS proxy is nil") + } + + var err error + configurator, err = platform.NewWindowsDNSConfigurator(interfaceName) + if err != nil { + return fmt.Errorf("failed to create Windows DNS configurator: %w", err) + } + + logger.Info("Using Windows registry DNS configurator for interface: %s", interfaceName) + + // Get current DNS servers before changing + currentDNS, err := configurator.GetCurrentDNS() + if err != nil { + logger.Warn("Could not get current DNS: %v", err) + } else { + logger.Info("Current DNS servers: %v", currentDNS) + } + + // Set new DNS servers to point to our proxy + newDNS := []netip.Addr{ + dnsProxy.GetProxyIP(), + } + + logger.Info("Setting DNS servers to: %v", newDNS) + originalDNS, err := configurator.SetDNS(newDNS) + if err != nil { + return fmt.Errorf("failed to set DNS: %w", err) + } 
+ + logger.Info("Original DNS servers backed up: %v", originalDNS) + return nil +} + +// RestoreDNSOverride restores the original DNS configuration +func RestoreDNSOverride() error { + if configurator == nil { + logger.Debug("No DNS configurator to restore") + return nil + } + + logger.Info("Restoring original DNS configuration") + if err := configurator.RestoreDNS(); err != nil { + return fmt.Errorf("failed to restore DNS: %w", err) + } + + logger.Info("DNS configuration restored successfully") + return nil +} diff --git a/dns/platform/darwin.go b/dns/platform/darwin.go new file mode 100644 index 0000000..a31f3a4 --- /dev/null +++ b/dns/platform/darwin.go @@ -0,0 +1,268 @@ +//go:build darwin && !ios + +package dns + +import ( + "bufio" + "bytes" + "fmt" + "net/netip" + "os/exec" + "strconv" + "strings" + + "github.com/fosrl/newt/logger" +) + +const ( + scutilPath = "/usr/sbin/scutil" + dscacheutilPath = "/usr/bin/dscacheutil" + + dnsStateKeyFormat = "State:/Network/Service/Olm-%s/DNS" + globalIPv4State = "State:/Network/Global/IPv4" + primaryServiceFormat = "State:/Network/Service/%s/DNS" + + keySupplementalMatchDomains = "SupplementalMatchDomains" + keySupplementalMatchDomainsNoSearch = "SupplementalMatchDomainsNoSearch" + keyServerAddresses = "ServerAddresses" + keyServerPort = "ServerPort" + arraySymbol = "* " + digitSymbol = "# " +) + +// DarwinDNSConfigurator manages DNS settings on macOS using scutil +type DarwinDNSConfigurator struct { + createdKeys map[string]struct{} + originalState *DNSState +} + +// NewDarwinDNSConfigurator creates a new macOS DNS configurator +func NewDarwinDNSConfigurator() (*DarwinDNSConfigurator, error) { + return &DarwinDNSConfigurator{ + createdKeys: make(map[string]struct{}), + }, nil +} + +// Name returns the configurator name +func (d *DarwinDNSConfigurator) Name() string { + return "darwin-scutil" +} + +// SetDNS sets the DNS servers and returns the original servers +func (d *DarwinDNSConfigurator) SetDNS(servers []netip.Addr) 
([]netip.Addr, error) { + // Get current DNS settings before overriding + originalServers, err := d.GetCurrentDNS() + if err != nil { + return nil, fmt.Errorf("get current DNS: %w", err) + } + + // Store original state + d.originalState = &DNSState{ + OriginalServers: originalServers, + ConfiguratorName: d.Name(), + } + + // Set new DNS servers + if err := d.applyDNSServers(servers); err != nil { + return nil, fmt.Errorf("apply DNS servers: %w", err) + } + + // Flush DNS cache + if err := d.flushDNSCache(); err != nil { + // Non-fatal, just log + fmt.Printf("warning: failed to flush DNS cache: %v\n", err) + } + + return originalServers, nil +} + +// RestoreDNS restores the original DNS configuration +func (d *DarwinDNSConfigurator) RestoreDNS() error { + // Remove all created keys + for key := range d.createdKeys { + if err := d.removeKey(key); err != nil { + return fmt.Errorf("remove key %s: %w", key, err) + } + } + + // Flush DNS cache + if err := d.flushDNSCache(); err != nil { + fmt.Printf("warning: failed to flush DNS cache: %v\n", err) + } + + return nil +} + +// GetCurrentDNS returns the currently configured DNS servers +func (d *DarwinDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) { + primaryServiceKey, err := d.getPrimaryServiceKey() + if err != nil || primaryServiceKey == "" { + return nil, fmt.Errorf("get primary service: %w", err) + } + + dnsKey := fmt.Sprintf(primaryServiceFormat, primaryServiceKey) + cmd := fmt.Sprintf("show %s\n", dnsKey) + + output, err := d.runScutil(cmd) + if err != nil { + return nil, fmt.Errorf("run scutil: %w", err) + } + + servers := d.parseServerAddresses(output) + return servers, nil +} + +// applyDNSServers applies the DNS server configuration +func (d *DarwinDNSConfigurator) applyDNSServers(servers []netip.Addr) error { + if len(servers) == 0 { + return fmt.Errorf("no DNS servers provided") + } + + key := fmt.Sprintf(dnsStateKeyFormat, "Override") + + // Use SupplementalMatchDomains with empty string to match ALL 
domains + // This is the key to making DNS override work on macOS + // Setting SupplementalMatchDomainsNoSearch to 0 enables search domain behavior + err := d.addDNSState(key, "\"\"", servers[0], 53, true) + if err != nil { + return fmt.Errorf("set DNS servers: %w", err) + } + + d.createdKeys[key] = struct{}{} + return nil +} + +// addDNSState adds a DNS state entry with the specified configuration +func (d *DarwinDNSConfigurator) addDNSState(state, domains string, dnsServer netip.Addr, port int, enableSearch bool) error { + noSearch := "1" + if enableSearch { + noSearch = "0" + } + + // Build the scutil command following NetBird's approach + var commands strings.Builder + commands.WriteString("d.init\n") + commands.WriteString(fmt.Sprintf("d.add %s %s%s\n", keySupplementalMatchDomains, arraySymbol, domains)) + commands.WriteString(fmt.Sprintf("d.add %s %s%s\n", keySupplementalMatchDomainsNoSearch, digitSymbol, noSearch)) + commands.WriteString(fmt.Sprintf("d.add %s %s%s\n", keyServerAddresses, arraySymbol, dnsServer.String())) + commands.WriteString(fmt.Sprintf("d.add %s %s%s\n", keyServerPort, digitSymbol, strconv.Itoa(port))) + commands.WriteString(fmt.Sprintf("set %s\n", state)) + + if _, err := d.runScutil(commands.String()); err != nil { + return fmt.Errorf("applying state for domains %s, error: %w", domains, err) + } + + logger.Info("Added DNS override with server %s:%d for domains: %s", dnsServer.String(), port, domains) + return nil +} + +// removeKey removes a DNS configuration key +func (d *DarwinDNSConfigurator) removeKey(key string) error { + cmd := fmt.Sprintf("remove %s\n", key) + + if _, err := d.runScutil(cmd); err != nil { + return fmt.Errorf("remove key: %w", err) + } + + delete(d.createdKeys, key) + return nil +} + +// getPrimaryServiceKey gets the primary network service key +func (d *DarwinDNSConfigurator) getPrimaryServiceKey() (string, error) { + cmd := fmt.Sprintf("show %s\n", globalIPv4State) + + output, err := d.runScutil(cmd) + if err != 
nil { + return "", fmt.Errorf("run scutil: %w", err) + } + + scanner := bufio.NewScanner(bytes.NewReader(output)) + for scanner.Scan() { + line := scanner.Text() + if strings.Contains(line, "PrimaryService") { + parts := strings.Split(line, ":") + if len(parts) >= 2 { + return strings.TrimSpace(parts[1]), nil + } + } + } + + if err := scanner.Err(); err != nil { + return "", fmt.Errorf("scan output: %w", err) + } + + return "", fmt.Errorf("primary service not found") +} + +// parseServerAddresses parses DNS server addresses from scutil output +func (d *DarwinDNSConfigurator) parseServerAddresses(output []byte) []netip.Addr { + var servers []netip.Addr + inServerArray := false + + scanner := bufio.NewScanner(bytes.NewReader(output)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + if strings.HasPrefix(line, "ServerAddresses : {") { + inServerArray = true + continue + } + + if line == "}" { + inServerArray = false + continue + } + + if inServerArray { + // Line format: "0 : 8.8.8.8" + parts := strings.Split(line, " : ") + if len(parts) >= 2 { + if addr, err := netip.ParseAddr(parts[1]); err == nil { + servers = append(servers, addr) + } + } + } + } + + return servers +} + +// flushDNSCache flushes the system DNS cache +func (d *DarwinDNSConfigurator) flushDNSCache() error { + logger.Debug("Flushing dscacheutil cache") + cmd := exec.Command(dscacheutilPath, "-flushcache") + if err := cmd.Run(); err != nil { + return fmt.Errorf("flush cache: %w", err) + } + + logger.Debug("Flushing mDNSResponder cache") + + cmd = exec.Command("killall", "-HUP", "mDNSResponder") + if err := cmd.Run(); err != nil { + // Non-fatal, mDNSResponder might not be running + return nil + } + + return nil +} + +// runScutil executes an scutil command +func (d *DarwinDNSConfigurator) runScutil(commands string) ([]byte, error) { + // Wrap commands with open/quit + wrapped := fmt.Sprintf("open\n%squit\n", commands) + + logger.Debug("Running scutil with commands:\n%s\n", 
wrapped) + + cmd := exec.Command(scutilPath) + cmd.Stdin = strings.NewReader(wrapped) + + output, err := cmd.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("scutil command failed: %w, output: %s", err, output) + } + + logger.Debug("scutil output:\n%s\n", output) + + return output, nil +} diff --git a/dns/platform/detect_unix.go b/dns/platform/detect_unix.go new file mode 100644 index 0000000..87b7dc7 --- /dev/null +++ b/dns/platform/detect_unix.go @@ -0,0 +1,158 @@ +//go:build (linux && !android) || freebsd + +package dns + +import ( + "bufio" + "io" + "os" + "strings" + + "github.com/fosrl/newt/logger" +) + +const defaultResolvConfPath = "/etc/resolv.conf" + +// DNSManagerType represents the type of DNS manager detected +type DNSManagerType int + +const ( + // UnknownManager indicates we couldn't determine the DNS manager + UnknownManager DNSManagerType = iota + // SystemdResolvedManager indicates systemd-resolved is managing DNS + SystemdResolvedManager + // NetworkManagerManager indicates NetworkManager is managing DNS + NetworkManagerManager + // ResolvconfManager indicates resolvconf is managing DNS + ResolvconfManager + // FileManager indicates direct file management (no DNS manager) + FileManager +) + +// DetectDNSManagerFromFile reads /etc/resolv.conf to determine which DNS manager is in use +// This provides a hint based on comments in the file, similar to Netbird's approach +func DetectDNSManagerFromFile() DNSManagerType { + file, err := os.Open(defaultResolvConfPath) + if err != nil { + return UnknownManager + } + defer file.Close() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + text := scanner.Text() + if len(text) == 0 { + continue + } + + // If we hit a non-comment line, default to file-based + if text[0] != '#' { + return FileManager + } + + // Check for DNS manager signatures in comments + if strings.Contains(text, "NetworkManager") { + return NetworkManagerManager + } + + if strings.Contains(text, "systemd-resolved") { 
+ return SystemdResolvedManager + } + + if strings.Contains(text, "resolvconf") { + return ResolvconfManager + } + } + + if err := scanner.Err(); err != nil && err != io.EOF { + return UnknownManager + } + + // No indicators found, assume file-based management + return FileManager +} + +// String returns a human-readable name for the DNS manager type +func (d DNSManagerType) String() string { + switch d { + case SystemdResolvedManager: + return "systemd-resolved" + case NetworkManagerManager: + return "NetworkManager" + case ResolvconfManager: + return "resolvconf" + case FileManager: + return "file" + default: + return "unknown" + } +} + +// DetectDNSManager combines file detection with runtime availability checks +// to determine the best DNS configurator to use +func DetectDNSManager(interfaceName string) DNSManagerType { + // First check what the file suggests + fileHint := DetectDNSManagerFromFile() + + // Verify the hint with runtime checks + switch fileHint { + case SystemdResolvedManager: + // Verify systemd-resolved is actually running + if IsSystemdResolvedAvailable() { + return SystemdResolvedManager + } + logger.Warn("dns platform: Found systemd-resolved but it is not running. Falling back to file...") + os.Exit(0) + return FileManager + + case NetworkManagerManager: + // Verify NetworkManager is actually running + if IsNetworkManagerAvailable() { + // Check if NetworkManager is delegating to systemd-resolved + if !IsNetworkManagerDNSModeSupported() { + logger.Info("NetworkManager is delegating DNS to systemd-resolved, using systemd-resolved configurator") + if IsSystemdResolvedAvailable() { + return SystemdResolvedManager + } + } + return NetworkManagerManager + } + logger.Warn("dns platform: Found network manager but it is not running. 
Falling back to file...") + return FileManager + + case ResolvconfManager: + // Verify resolvconf is available + if IsResolvconfAvailable() { + return ResolvconfManager + } + // If resolvconf is mentioned but not available, fall back to file + return FileManager + + case FileManager: + // File suggests direct file management + // But we should still check if a manager is available that wasn't mentioned + if IsSystemdResolvedAvailable() && interfaceName != "" { + return SystemdResolvedManager + } + if IsNetworkManagerAvailable() && interfaceName != "" { + return NetworkManagerManager + } + if IsResolvconfAvailable() && interfaceName != "" { + return ResolvconfManager + } + return FileManager + + default: + // Unknown - do runtime detection + if IsSystemdResolvedAvailable() && interfaceName != "" { + return SystemdResolvedManager + } + if IsNetworkManagerAvailable() && interfaceName != "" { + return NetworkManagerManager + } + if IsResolvconfAvailable() && interfaceName != "" { + return ResolvconfManager + } + return FileManager + } +} diff --git a/dns/platform/file.go b/dns/platform/file.go new file mode 100644 index 0000000..8f6f766 --- /dev/null +++ b/dns/platform/file.go @@ -0,0 +1,192 @@ +//go:build (linux && !android) || freebsd + +package dns + +import ( + "fmt" + "net/netip" + "os" + "strings" +) + +const ( + resolvConfPath = "/etc/resolv.conf" + resolvConfBackupPath = "/etc/resolv.conf.olm.backup" + resolvConfHeader = "# Generated by Olm DNS Manager\n# Original file backed up to " + resolvConfBackupPath + "\n\n" +) + +// FileDNSConfigurator manages DNS settings by directly modifying /etc/resolv.conf +type FileDNSConfigurator struct { + originalState *DNSState +} + +// NewFileDNSConfigurator creates a new file-based DNS configurator +func NewFileDNSConfigurator() (*FileDNSConfigurator, error) { + return &FileDNSConfigurator{}, nil +} + +// Name returns the configurator name +func (f *FileDNSConfigurator) Name() string { + return "file-resolv.conf" +} + +// 
SetDNS sets the DNS servers and returns the original servers +func (f *FileDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) { + // Get current DNS settings before overriding + originalServers, err := f.GetCurrentDNS() + if err != nil { + return nil, fmt.Errorf("get current DNS: %w", err) + } + + // Backup original resolv.conf if not already backed up + if !f.isBackupExists() { + if err := f.backupResolvConf(); err != nil { + return nil, fmt.Errorf("backup resolv.conf: %w", err) + } + } + + // Store original state + f.originalState = &DNSState{ + OriginalServers: originalServers, + ConfiguratorName: f.Name(), + } + + // Write new resolv.conf + if err := f.writeResolvConf(servers); err != nil { + return nil, fmt.Errorf("write resolv.conf: %w", err) + } + + return originalServers, nil +} + +// RestoreDNS restores the original DNS configuration +func (f *FileDNSConfigurator) RestoreDNS() error { + if !f.isBackupExists() { + return fmt.Errorf("no backup file exists") + } + + // Copy backup back to original location + if err := copyFile(resolvConfBackupPath, resolvConfPath); err != nil { + return fmt.Errorf("restore from backup: %w", err) + } + + // Remove backup file + if err := os.Remove(resolvConfBackupPath); err != nil { + return fmt.Errorf("remove backup file: %w", err) + } + + return nil +} + +// GetCurrentDNS returns the currently configured DNS servers +func (f *FileDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) { + content, err := os.ReadFile(resolvConfPath) + if err != nil { + return nil, fmt.Errorf("read resolv.conf: %w", err) + } + + return f.parseNameservers(string(content)), nil +} + +// backupResolvConf creates a backup of the current resolv.conf +func (f *FileDNSConfigurator) backupResolvConf() error { + // Get file info for permissions + info, err := os.Stat(resolvConfPath) + if err != nil { + return fmt.Errorf("stat resolv.conf: %w", err) + } + + if err := copyFile(resolvConfPath, resolvConfBackupPath); err != nil { + return 
// copyFile writes the full contents of src to dst.
//
// The source is read into memory first and then stat'ed; the source's file
// mode is passed to os.WriteFile as the permission for dst. Each failing step
// is reported with its own prefix ("read source", "stat source",
// "write destination").
func copyFile(src, dst string) error {
	data, readErr := os.ReadFile(src)
	if readErr != nil {
		return fmt.Errorf("read source: %w", readErr)
	}

	// Source permissions are reused for the destination.
	srcInfo, statErr := os.Stat(src)
	if statErr != nil {
		return fmt.Errorf("stat source: %w", statErr)
	}

	writeErr := os.WriteFile(dst, data, srcInfo.Mode())
	if writeErr != nil {
		return fmt.Errorf("write destination: %w", writeErr)
	}

	return nil
}
fmt.Errorf("interface name is required") + } + + // Check that NetworkManager conf.d directory exists + if _, err := os.Stat(networkManagerConfDir); os.IsNotExist(err) { + return nil, fmt.Errorf("NetworkManager conf.d directory not found: %s", networkManagerConfDir) + } + + return &NetworkManagerDNSConfigurator{ + ifaceName: ifaceName, + confPath: networkManagerConfDir + "/" + networkManagerDNSConfFile, + dispatchPath: networkManagerDispatcherDir + "/" + networkManagerDispatcherFile, + }, nil +} + +// Name returns the configurator name +func (n *NetworkManagerDNSConfigurator) Name() string { + return "network-manager" +} + +// SetDNS sets the DNS servers and returns the original servers +func (n *NetworkManagerDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) { + // Get current DNS settings before overriding + originalServers, err := n.GetCurrentDNS() + if err != nil { + // If we can't get current DNS, proceed anyway + originalServers = []netip.Addr{} + } + + // Store original state + n.originalState = &DNSState{ + OriginalServers: originalServers, + ConfiguratorName: n.Name(), + } + + // Apply new DNS servers + if err := n.applyDNSServers(servers); err != nil { + return nil, fmt.Errorf("apply DNS servers: %w", err) + } + + return originalServers, nil +} + +// RestoreDNS restores the original DNS configuration +func (n *NetworkManagerDNSConfigurator) RestoreDNS() error { + // Remove our configuration file + if err := os.Remove(n.confPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("remove DNS config file: %w", err) + } + + // Reload NetworkManager to apply the change + if err := n.reloadNetworkManager(); err != nil { + return fmt.Errorf("reload NetworkManager: %w", err) + } + + return nil +} + +// GetCurrentDNS returns the currently configured DNS servers by reading /etc/resolv.conf +func (n *NetworkManagerDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) { + content, err := os.ReadFile("/etc/resolv.conf") + if err != nil { + 
return nil, fmt.Errorf("read resolv.conf: %w", err) + } + + var servers []netip.Addr + lines := strings.Split(string(content), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "nameserver") { + fields := strings.Fields(line) + if len(fields) >= 2 { + if addr, err := netip.ParseAddr(fields[1]); err == nil { + servers = append(servers, addr) + } + } + } + } + + return servers, nil +} + +// applyDNSServers applies DNS server configuration via NetworkManager config file +func (n *NetworkManagerDNSConfigurator) applyDNSServers(servers []netip.Addr) error { + if len(servers) == 0 { + return fmt.Errorf("no DNS servers provided") + } + + // Build DNS server list + var dnsServers []string + for _, server := range servers { + dnsServers = append(dnsServers, server.String()) + } + + // Create NetworkManager configuration file that sets global DNS + // This overrides DNS for all connections + configContent := fmt.Sprintf(`# Generated by Olm DNS Manager - DO NOT EDIT +# This file configures NetworkManager to use Olm's DNS proxy + +[global-dns-domain-*] +servers=%s +`, strings.Join(dnsServers, ",")) + + // Write the configuration file + if err := os.WriteFile(n.confPath, []byte(configContent), 0644); err != nil { + return fmt.Errorf("write DNS config file: %w", err) + } + + // Reload NetworkManager to apply the new configuration + if err := n.reloadNetworkManager(); err != nil { + // Try to clean up + os.Remove(n.confPath) + return fmt.Errorf("reload NetworkManager: %w", err) + } + + return nil +} + +// reloadNetworkManager tells NetworkManager to reload its configuration +func (n *NetworkManagerDNSConfigurator) reloadNetworkManager() error { + conn, err := dbus.SystemBus() + if err != nil { + return fmt.Errorf("connect to system bus: %w", err) + } + defer conn.Close() + + obj := conn.Object(networkManagerDest, networkManagerDbusObjectNode) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer 
cancel() + + // Call Reload method with flags=0 (reload everything) + // See: https://networkmanager.dev/docs/api/latest/gdbus-org.freedesktop.NetworkManager.html#gdbus-method-org-freedesktop-NetworkManager.Reload + err = obj.CallWithContext(ctx, networkManagerDest+".Reload", 0, uint32(0)).Store() + if err != nil { + return fmt.Errorf("call Reload: %w", err) + } + + return nil +} + +// IsNetworkManagerAvailable checks if NetworkManager is available and responsive +func IsNetworkManagerAvailable() bool { + conn, err := dbus.SystemBus() + if err != nil { + return false + } + defer conn.Close() + + obj := conn.Object(networkManagerDest, networkManagerDbusObjectNode) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + // Try to ping NetworkManager + if err := obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0).Store(); err != nil { + return false + } + + return true +} + +// IsNetworkManagerDNSModeSupported checks if NetworkManager's DNS mode is one we can work with +// Some DNS modes delegate to other systems (like systemd-resolved) which we should use directly +func IsNetworkManagerDNSModeSupported() bool { + conn, err := dbus.SystemBus() + if err != nil { + return false + } + defer conn.Close() + + obj := conn.Object(networkManagerDest, networkManagerDbusDNSManagerObjectNode) + + modeVariant, err := obj.GetProperty(networkManagerDbusDNSManagerModeProperty) + if err != nil { + // If we can't get the mode, assume it's not supported + return false + } + + mode, ok := modeVariant.Value().(string) + if !ok { + return false + } + + // If NetworkManager is delegating DNS to systemd-resolved, we should use + // systemd-resolved directly for better control + switch mode { + case "systemd-resolved": + // NetworkManager is delegating to systemd-resolved + // We should use systemd-resolved configurator instead + return false + case "dnsmasq", "unbound": + // NetworkManager is using a local resolver that it controls + // We 
can configure DNS through NetworkManager + return true + case "default", "none", "": + // NetworkManager is managing DNS directly or not at all + // We can configure DNS through NetworkManager + return true + default: + // Unknown mode, try to use it + return true + } +} + +// GetNetworkManagerDNSMode returns the current DNS mode of NetworkManager +func GetNetworkManagerDNSMode() (string, error) { + conn, err := dbus.SystemBus() + if err != nil { + return "", fmt.Errorf("connect to system bus: %w", err) + } + defer conn.Close() + + obj := conn.Object(networkManagerDest, networkManagerDbusDNSManagerObjectNode) + + modeVariant, err := obj.GetProperty(networkManagerDbusDNSManagerModeProperty) + if err != nil { + return "", fmt.Errorf("get DNS mode property: %w", err) + } + + mode, ok := modeVariant.Value().(string) + if !ok { + return "", errors.New("DNS mode is not a string") + } + + return mode, nil +} + +// GetNetworkManagerVersion returns the version of NetworkManager +func GetNetworkManagerVersion() (string, error) { + conn, err := dbus.SystemBus() + if err != nil { + return "", fmt.Errorf("connect to system bus: %w", err) + } + defer conn.Close() + + obj := conn.Object(networkManagerDest, networkManagerDbusObjectNode) + + versionVariant, err := obj.GetProperty(networkManagerDbusVersionProperty) + if err != nil { + return "", fmt.Errorf("get version property: %w", err) + } + + version, ok := versionVariant.Value().(string) + if !ok { + return "", errors.New("version is not a string") + } + + return version, nil +} diff --git a/dns/platform/resolvconf.go b/dns/platform/resolvconf.go new file mode 100644 index 0000000..4202c4c --- /dev/null +++ b/dns/platform/resolvconf.go @@ -0,0 +1,192 @@ +//go:build (linux && !android) || freebsd + +package dns + +import ( + "bytes" + "fmt" + "net/netip" + "os/exec" + "strings" +) + +const resolvconfCommand = "resolvconf" + +// ResolvconfDNSConfigurator manages DNS settings using the resolvconf utility +type 
ResolvconfDNSConfigurator struct { + ifaceName string + implType string + originalState *DNSState +} + +// NewResolvconfDNSConfigurator creates a new resolvconf DNS configurator +func NewResolvconfDNSConfigurator(ifaceName string) (*ResolvconfDNSConfigurator, error) { + if ifaceName == "" { + return nil, fmt.Errorf("interface name is required") + } + + // Detect resolvconf implementation type + implType, err := detectResolvconfType() + if err != nil { + return nil, fmt.Errorf("detect resolvconf type: %w", err) + } + + return &ResolvconfDNSConfigurator{ + ifaceName: ifaceName, + implType: implType, + }, nil +} + +// Name returns the configurator name +func (r *ResolvconfDNSConfigurator) Name() string { + return fmt.Sprintf("resolvconf-%s", r.implType) +} + +// SetDNS sets the DNS servers and returns the original servers +func (r *ResolvconfDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) { + // Get current DNS settings before overriding + originalServers, err := r.GetCurrentDNS() + if err != nil { + // If we can't get current DNS, proceed anyway + originalServers = []netip.Addr{} + } + + // Store original state + r.originalState = &DNSState{ + OriginalServers: originalServers, + ConfiguratorName: r.Name(), + } + + // Apply new DNS servers + if err := r.applyDNSServers(servers); err != nil { + return nil, fmt.Errorf("apply DNS servers: %w", err) + } + + return originalServers, nil +} + +// RestoreDNS restores the original DNS configuration +func (r *ResolvconfDNSConfigurator) RestoreDNS() error { + var cmd *exec.Cmd + + switch r.implType { + case "openresolv": + // Force delete with -f + cmd = exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName) + default: + cmd = exec.Command(resolvconfCommand, "-d", r.ifaceName) + } + + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("delete resolvconf config: %w, output: %s", err, out) + } + + return nil +} + +// GetCurrentDNS returns the currently configured DNS servers +func (r 
*ResolvconfDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) { + // resolvconf doesn't provide a direct way to query per-interface DNS + // We can try to read /etc/resolv.conf but it's merged from all sources + content, err := exec.Command(resolvconfCommand, "-l").CombinedOutput() + if err != nil { + // Fall back to reading resolv.conf + return readResolvConfServers() + } + + // Parse the output (format varies by implementation) + return parseResolvconfOutput(string(content)), nil +} + +// applyDNSServers applies DNS server configuration via resolvconf +func (r *ResolvconfDNSConfigurator) applyDNSServers(servers []netip.Addr) error { + if len(servers) == 0 { + return fmt.Errorf("no DNS servers provided") + } + + // Build resolv.conf content + var content bytes.Buffer + content.WriteString("# Generated by Olm DNS Manager\n\n") + + for _, server := range servers { + content.WriteString("nameserver ") + content.WriteString(server.String()) + content.WriteString("\n") + } + + // Apply via resolvconf + var cmd *exec.Cmd + switch r.implType { + case "openresolv": + // OpenResolv supports exclusive mode with -x + cmd = exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName) + default: + cmd = exec.Command(resolvconfCommand, "-a", r.ifaceName) + } + + cmd.Stdin = &content + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("apply resolvconf config: %w, output: %s", err, out) + } + + return nil +} + +// detectResolvconfType detects which resolvconf implementation is being used +func detectResolvconfType() (string, error) { + cmd := exec.Command(resolvconfCommand, "--version") + out, err := cmd.CombinedOutput() + if err != nil { + return "", fmt.Errorf("detect resolvconf type: %w", err) + } + + if strings.Contains(string(out), "openresolv") { + return "openresolv", nil + } + + return "resolvconf", nil +} + +// parseResolvconfOutput parses resolvconf -l output for DNS servers +func parseResolvconfOutput(output string) []netip.Addr { + var servers 
[]netip.Addr + + lines := strings.Split(output, "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + + // Skip comments and empty lines + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + // Look for nameserver lines + if strings.HasPrefix(line, "nameserver") { + fields := strings.Fields(line) + if len(fields) >= 2 { + if addr, err := netip.ParseAddr(fields[1]); err == nil { + servers = append(servers, addr) + } + } + } + } + + return servers +} + +// readResolvConfServers reads DNS servers from /etc/resolv.conf +func readResolvConfServers() ([]netip.Addr, error) { + cmd := exec.Command("cat", "/etc/resolv.conf") + out, err := cmd.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("read resolv.conf: %w", err) + } + + return parseResolvconfOutput(string(out)), nil +} + +// IsResolvconfAvailable checks if resolvconf is available +func IsResolvconfAvailable() bool { + cmd := exec.Command(resolvconfCommand, "--version") + return cmd.Run() == nil +} diff --git a/dns/platform/systemd.go b/dns/platform/systemd.go new file mode 100644 index 0000000..61f9ca6 --- /dev/null +++ b/dns/platform/systemd.go @@ -0,0 +1,286 @@ +//go:build linux && !android + +package dns + +import ( + "context" + "fmt" + "net" + "net/netip" + "time" + + dbus "github.com/godbus/dbus/v5" + "golang.org/x/sys/unix" +) + +const ( + systemdResolvedDest = "org.freedesktop.resolve1" + systemdDbusObjectNode = "/org/freedesktop/resolve1" + systemdDbusManagerIface = "org.freedesktop.resolve1.Manager" + systemdDbusGetLinkMethod = systemdDbusManagerIface + ".GetLink" + systemdDbusFlushCachesMethod = systemdDbusManagerIface + ".FlushCaches" + systemdDbusLinkInterface = "org.freedesktop.resolve1.Link" + systemdDbusSetDNSMethod = systemdDbusLinkInterface + ".SetDNS" + systemdDbusSetDefaultRouteMethod = systemdDbusLinkInterface + ".SetDefaultRoute" + systemdDbusSetDomainsMethod = systemdDbusLinkInterface + ".SetDomains" + systemdDbusSetDNSSECMethod = 
systemdDbusLinkInterface + ".SetDNSSEC" + systemdDbusSetDNSOverTLSMethod = systemdDbusLinkInterface + ".SetDNSOverTLS" + systemdDbusRevertMethod = systemdDbusLinkInterface + ".Revert" + + // RootZone is the root DNS zone that matches all queries + RootZone = "." +) + +// systemdDbusDNSInput maps to (iay) dbus input for SetDNS method +type systemdDbusDNSInput struct { + Family int32 + Address []byte +} + +// systemdDbusDomainsInput maps to (sb) dbus input for SetDomains method +type systemdDbusDomainsInput struct { + Domain string + MatchOnly bool +} + +// SystemdResolvedDNSConfigurator manages DNS settings using systemd-resolved D-Bus API +type SystemdResolvedDNSConfigurator struct { + ifaceName string + dbusLinkObject dbus.ObjectPath + originalState *DNSState +} + +// NewSystemdResolvedDNSConfigurator creates a new systemd-resolved DNS configurator +func NewSystemdResolvedDNSConfigurator(ifaceName string) (*SystemdResolvedDNSConfigurator, error) { + // Get network interface + iface, err := net.InterfaceByName(ifaceName) + if err != nil { + return nil, fmt.Errorf("get interface: %w", err) + } + + // Connect to D-Bus + conn, err := dbus.SystemBus() + if err != nil { + return nil, fmt.Errorf("connect to system bus: %w", err) + } + defer conn.Close() + + obj := conn.Object(systemdResolvedDest, systemdDbusObjectNode) + + // Get the link object for this interface + var linkPath string + if err := obj.Call(systemdDbusGetLinkMethod, 0, iface.Index).Store(&linkPath); err != nil { + return nil, fmt.Errorf("get link: %w", err) + } + + return &SystemdResolvedDNSConfigurator{ + ifaceName: ifaceName, + dbusLinkObject: dbus.ObjectPath(linkPath), + }, nil +} + +// Name returns the configurator name +func (s *SystemdResolvedDNSConfigurator) Name() string { + return "systemd-resolved" +} + +// SetDNS sets the DNS servers and returns the original servers +func (s *SystemdResolvedDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) { + // Get current DNS settings 
before overriding + originalServers, err := s.GetCurrentDNS() + if err != nil { + // If we can't get current DNS, proceed anyway + originalServers = []netip.Addr{} + } + + // Store original state + s.originalState = &DNSState{ + OriginalServers: originalServers, + ConfiguratorName: s.Name(), + } + + // Apply new DNS servers + if err := s.applyDNSServers(servers); err != nil { + return nil, fmt.Errorf("apply DNS servers: %w", err) + } + + return originalServers, nil +} + +// RestoreDNS restores the original DNS configuration +func (s *SystemdResolvedDNSConfigurator) RestoreDNS() error { + // Call Revert method to restore systemd-resolved defaults + conn, err := dbus.SystemBus() + if err != nil { + return fmt.Errorf("connect to system bus: %w", err) + } + defer conn.Close() + + obj := conn.Object(systemdResolvedDest, s.dbusLinkObject) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := obj.CallWithContext(ctx, systemdDbusRevertMethod, 0).Store(); err != nil { + return fmt.Errorf("revert DNS settings: %w", err) + } + + // Flush DNS cache after reverting + if err := s.flushDNSCache(); err != nil { + fmt.Printf("warning: failed to flush DNS cache: %v\n", err) + } + + return nil +} + +// GetCurrentDNS returns the currently configured DNS servers +// Note: systemd-resolved doesn't easily expose current per-link DNS servers via D-Bus +// This is a placeholder that returns an empty list +func (s *SystemdResolvedDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) { + // systemd-resolved's D-Bus API doesn't have a simple way to query current DNS servers + // We would need to parse resolvectl status output or read from /run/systemd/resolve/ + // For now, return empty list + return []netip.Addr{}, nil +} + +// applyDNSServers applies DNS server configuration via systemd-resolved +func (s *SystemdResolvedDNSConfigurator) applyDNSServers(servers []netip.Addr) error { + if len(servers) == 0 { + return fmt.Errorf("no DNS 
servers provided") + } + + // Convert servers to systemd-resolved format + var dnsInputs []systemdDbusDNSInput + for _, server := range servers { + family := unix.AF_INET + if server.Is6() { + family = unix.AF_INET6 + } + + dnsInputs = append(dnsInputs, systemdDbusDNSInput{ + Family: int32(family), + Address: server.AsSlice(), + }) + } + + // Connect to D-Bus + conn, err := dbus.SystemBus() + if err != nil { + return fmt.Errorf("connect to system bus: %w", err) + } + defer conn.Close() + + obj := conn.Object(systemdResolvedDest, s.dbusLinkObject) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Call SetDNS method to set the DNS servers + if err := obj.CallWithContext(ctx, systemdDbusSetDNSMethod, 0, dnsInputs).Store(); err != nil { + return fmt.Errorf("set DNS servers: %w", err) + } + + // Set this interface as the default route for DNS + // This ensures all DNS queries prefer this interface + if err := s.callLinkMethod(systemdDbusSetDefaultRouteMethod, true); err != nil { + return fmt.Errorf("set default route: %w", err) + } + + // Set the root zone "." 
as a match-only domain + // This captures ALL DNS queries and routes them through this interface + domainsInput := []systemdDbusDomainsInput{ + { + Domain: RootZone, + MatchOnly: true, + }, + } + if err := s.callLinkMethod(systemdDbusSetDomainsMethod, domainsInput); err != nil { + return fmt.Errorf("set domains: %w", err) + } + + // Disable DNSSEC - we don't support it and it may be enabled by default + if err := s.callLinkMethod(systemdDbusSetDNSSECMethod, "no"); err != nil { + // Log warning but don't fail - this is optional + fmt.Printf("warning: failed to disable DNSSEC: %v\n", err) + } + + // Disable DNSOverTLS - we don't support it and it may be enabled by default + if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethod, "no"); err != nil { + // Log warning but don't fail - this is optional + fmt.Printf("warning: failed to disable DNSOverTLS: %v\n", err) + } + + // Flush DNS cache to ensure new settings take effect immediately + if err := s.flushDNSCache(); err != nil { + fmt.Printf("warning: failed to flush DNS cache: %v\n", err) + } + + return nil +} + +// callLinkMethod is a helper to call methods on the link object +func (s *SystemdResolvedDNSConfigurator) callLinkMethod(method string, value any) error { + conn, err := dbus.SystemBus() + if err != nil { + return fmt.Errorf("connect to system bus: %w", err) + } + defer conn.Close() + + obj := conn.Object(systemdResolvedDest, s.dbusLinkObject) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if value != nil { + if err := obj.CallWithContext(ctx, method, 0, value).Store(); err != nil { + return fmt.Errorf("call %s: %w", method, err) + } + } else { + if err := obj.CallWithContext(ctx, method, 0).Store(); err != nil { + return fmt.Errorf("call %s: %w", method, err) + } + } + + return nil +} + +// flushDNSCache flushes the systemd-resolved DNS cache +func (s *SystemdResolvedDNSConfigurator) flushDNSCache() error { + conn, err := dbus.SystemBus() + if err != nil 
{ + return fmt.Errorf("connect to system bus: %w", err) + } + defer conn.Close() + + obj := conn.Object(systemdResolvedDest, systemdDbusObjectNode) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := obj.CallWithContext(ctx, systemdDbusFlushCachesMethod, 0).Store(); err != nil { + return fmt.Errorf("flush caches: %w", err) + } + + return nil +} + +// IsSystemdResolvedAvailable checks if systemd-resolved is available and responsive +func IsSystemdResolvedAvailable() bool { + conn, err := dbus.SystemBus() + if err != nil { + return false + } + defer conn.Close() + + obj := conn.Object(systemdResolvedDest, systemdDbusObjectNode) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + // Try to ping systemd-resolved + if err := obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0).Store(); err != nil { + return false + } + + return true +} diff --git a/dns/platform/types.go b/dns/platform/types.go new file mode 100644 index 0000000..471ba29 --- /dev/null +++ b/dns/platform/types.go @@ -0,0 +1,41 @@ +package dns + +import "net/netip" + +// DNSConfigurator provides an interface for managing system DNS settings +// across different platforms and implementations +type DNSConfigurator interface { + // SetDNS overrides the system DNS servers with the specified ones + // Returns the original DNS servers that were replaced + SetDNS(servers []netip.Addr) ([]netip.Addr, error) + + // RestoreDNS restores the original DNS servers + RestoreDNS() error + + // GetCurrentDNS returns the currently configured DNS servers + GetCurrentDNS() ([]netip.Addr, error) + + // Name returns the name of this configurator implementation + Name() string +} + +// DNSConfig contains the configuration for DNS override +type DNSConfig struct { + // Servers is the list of DNS servers to use + Servers []netip.Addr + + // SearchDomains is an optional list of search domains + SearchDomains []string +} + 
+// DNSState represents the saved state of DNS configuration +type DNSState struct { + // OriginalServers are the DNS servers before override + OriginalServers []netip.Addr + + // OriginalSearchDomains are the search domains before override + OriginalSearchDomains []string + + // ConfiguratorName is the name of the configurator that saved this state + ConfiguratorName string +} diff --git a/dns/platform/windows.go b/dns/platform/windows.go new file mode 100644 index 0000000..f4c5896 --- /dev/null +++ b/dns/platform/windows.go @@ -0,0 +1,343 @@ +//go:build windows + +package dns + +import ( + "errors" + "fmt" + "io" + "net" + "net/netip" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" +) + +var ( + dnsapi = syscall.NewLazyDLL("dnsapi.dll") + dnsFlushResolverCacheFn = dnsapi.NewProc("DnsFlushResolverCache") +) + +const ( + interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces` + interfaceConfigNameServer = "NameServer" + interfaceConfigDhcpNameServer = "DhcpNameServer" +) + +// WindowsDNSConfigurator manages DNS settings on Windows using the registry +type WindowsDNSConfigurator struct { + guid string + originalState *DNSState +} + +// NewWindowsDNSConfigurator creates a new Windows DNS configurator +// Accepts an interface name and extracts the GUID internally +func NewWindowsDNSConfigurator(interfaceName string) (*WindowsDNSConfigurator, error) { + if interfaceName == "" { + return nil, fmt.Errorf("interface name is required") + } + + guid, err := getInterfaceGUIDString(interfaceName) + if err != nil { + return nil, fmt.Errorf("failed to get interface GUID: %w", err) + } + + return &WindowsDNSConfigurator{ + guid: guid, + }, nil +} + +// newWindowsDNSConfiguratorFromGUID creates a configurator from a GUID string +// This is an internal function for use by DetectBestConfigurator +func newWindowsDNSConfiguratorFromGUID(guid string) (*WindowsDNSConfigurator, error) { + if guid == "" { + return 
nil, fmt.Errorf("interface GUID is required") + } + + return &WindowsDNSConfigurator{ + guid: guid, + }, nil +} + +// Name returns the configurator name +func (w *WindowsDNSConfigurator) Name() string { + return "windows-registry" +} + +// SetDNS sets the DNS servers and returns the original servers +func (w *WindowsDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) { + // Get current DNS settings before overriding + originalServers, err := w.GetCurrentDNS() + if err != nil { + return nil, fmt.Errorf("get current DNS: %w", err) + } + + // Store original state + w.originalState = &DNSState{ + OriginalServers: originalServers, + ConfiguratorName: w.Name(), + } + + // Set new DNS servers + if err := w.setDNSServers(servers); err != nil { + return nil, fmt.Errorf("set DNS servers: %w", err) + } + + // Flush DNS cache + if err := w.flushDNSCache(); err != nil { + // Non-fatal, just log + fmt.Printf("warning: failed to flush DNS cache: %v\n", err) + } + + return originalServers, nil +} + +// RestoreDNS restores the original DNS configuration +func (w *WindowsDNSConfigurator) RestoreDNS() error { + if w.originalState == nil { + return fmt.Errorf("no original state to restore") + } + + // Clear the static DNS setting + if err := w.clearDNSServers(); err != nil { + return fmt.Errorf("clear DNS servers: %w", err) + } + + // Flush DNS cache + if err := w.flushDNSCache(); err != nil { + fmt.Printf("warning: failed to flush DNS cache: %v\n", err) + } + + return nil +} + +// GetCurrentDNS returns the currently configured DNS servers +func (w *WindowsDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) { + regKey, err := w.getInterfaceRegistryKey(registry.QUERY_VALUE) + if err != nil { + return nil, fmt.Errorf("get interface registry key: %w", err) + } + defer closeKey(regKey) + + // Try to get static DNS first + nameServer, _, err := regKey.GetStringValue(interfaceConfigNameServer) + if err == nil && nameServer != "" { + return w.parseServerList(nameServer), 
nil + } + + // Fall back to DHCP DNS + dhcpNameServer, _, err := regKey.GetStringValue(interfaceConfigDhcpNameServer) + if err == nil && dhcpNameServer != "" { + return w.parseServerList(dhcpNameServer), nil + } + + return []netip.Addr{}, nil +} + +// setDNSServers sets the DNS servers in the registry +func (w *WindowsDNSConfigurator) setDNSServers(servers []netip.Addr) error { + if len(servers) == 0 { + return fmt.Errorf("no DNS servers provided") + } + + regKey, err := w.getInterfaceRegistryKey(registry.SET_VALUE) + if err != nil { + return fmt.Errorf("get interface registry key: %w", err) + } + defer closeKey(regKey) + + // Build comma-separated or space-separated list of servers + var serverList string + for i, server := range servers { + if i > 0 { + serverList += "," + } + serverList += server.String() + } + + if err := regKey.SetStringValue(interfaceConfigNameServer, serverList); err != nil { + return fmt.Errorf("set NameServer: %w", err) + } + + return nil +} + +// clearDNSServers clears the static DNS server setting +func (w *WindowsDNSConfigurator) clearDNSServers() error { + regKey, err := w.getInterfaceRegistryKey(registry.SET_VALUE) + if err != nil { + return fmt.Errorf("get interface registry key: %w", err) + } + defer closeKey(regKey) + + // Set empty string to revert to DHCP + if err := regKey.SetStringValue(interfaceConfigNameServer, ""); err != nil { + return fmt.Errorf("clear NameServer: %w", err) + } + + return nil +} + +// getInterfaceRegistryKey opens the registry key for the network interface +func (w *WindowsDNSConfigurator) getInterfaceRegistryKey(access uint32) (registry.Key, error) { + regKeyPath := interfaceConfigPath + `\` + w.guid + + regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, access) + if err != nil { + return 0, fmt.Errorf("open HKEY_LOCAL_MACHINE\\%s: %w", regKeyPath, err) + } + + return regKey, nil +} + +// parseServerList parses a comma or space-separated list of DNS servers +func (w 
*WindowsDNSConfigurator) parseServerList(serverList string) []netip.Addr { + var servers []netip.Addr + + // Split by comma or space + parts := splitByDelimiters(serverList, []rune{',', ' '}) + + for _, part := range parts { + if addr, err := netip.ParseAddr(part); err == nil { + servers = append(servers, addr) + } + } + + return servers +} + +// flushDNSCache flushes the Windows DNS resolver cache +func (w *WindowsDNSConfigurator) flushDNSCache() error { + // dnsFlushResolverCacheFn.Call() may panic if the func is not found + defer func() { + if rec := recover(); rec != nil { + fmt.Printf("warning: DnsFlushResolverCache panicked: %v\n", rec) + } + }() + + ret, _, err := dnsFlushResolverCacheFn.Call() + if ret == 0 { + if err != nil && !errors.Is(err, syscall.Errno(0)) { + return fmt.Errorf("DnsFlushResolverCache failed: %w", err) + } + return fmt.Errorf("DnsFlushResolverCache failed") + } + + return nil +} + +// splitByDelimiters splits a string by multiple delimiters +func splitByDelimiters(s string, delimiters []rune) []string { + var result []string + var current []rune + + for _, char := range s { + isDelimiter := false + for _, delim := range delimiters { + if char == delim { + isDelimiter = true + break + } + } + + if isDelimiter { + if len(current) > 0 { + result = append(result, string(current)) + current = []rune{} + } + } else { + current = append(current, char) + } + } + + if len(current) > 0 { + result = append(result, string(current)) + } + + return result +} + +// closeKey closes a registry key and logs errors +func closeKey(closer io.Closer) { + if err := closer.Close(); err != nil { + fmt.Printf("warning: failed to close registry key: %v\n", err) + } +} + +// getInterfaceGUIDString retrieves the GUID string for a Windows TUN interface +// This is required for registry-based DNS configuration on Windows +func getInterfaceGUIDString(interfaceName string) (string, error) { + if interfaceName == "" { + return "", fmt.Errorf("interface name is 
required") + } + + iface, err := net.InterfaceByName(interfaceName) + if err != nil { + return "", fmt.Errorf("failed to get interface %s: %w", interfaceName, err) + } + + luid, err := indexToLUID(uint32(iface.Index)) + if err != nil { + return "", fmt.Errorf("failed to convert index to LUID: %w", err) + } + + // Convert LUID to GUID using Windows API + guid, err := luidToGUID(luid) + if err != nil { + return "", fmt.Errorf("failed to convert LUID to GUID: %w", err) + } + + return guid, nil +} + +// indexToLUID converts a Windows interface index to a LUID +func indexToLUID(index uint32) (uint64, error) { + var luid uint64 + + // Load the iphlpapi.dll and get the ConvertInterfaceIndexToLuid function + iphlpapi := windows.NewLazySystemDLL("iphlpapi.dll") + convertInterfaceIndexToLuid := iphlpapi.NewProc("ConvertInterfaceIndexToLuid") + + // Call the Windows API + ret, _, err := convertInterfaceIndexToLuid.Call( + uintptr(index), + uintptr(unsafe.Pointer(&luid)), + ) + + if ret != 0 { + return 0, fmt.Errorf("ConvertInterfaceIndexToLuid failed with code %d: %w", ret, err) + } + + return luid, nil +} + +// luidToGUID converts a Windows LUID (Locally Unique Identifier) to a GUID string +// using the Windows ConvertInterface* APIs +func luidToGUID(luid uint64) (string, error) { + var guid windows.GUID + + // Load the iphlpapi.dll and get the ConvertInterfaceLuidToGuid function + iphlpapi := windows.NewLazySystemDLL("iphlpapi.dll") + convertLuidToGuid := iphlpapi.NewProc("ConvertInterfaceLuidToGuid") + + // Call the Windows API + // NET_LUID is a 64-bit value on Windows + ret, _, err := convertLuidToGuid.Call( + uintptr(unsafe.Pointer(&luid)), + uintptr(unsafe.Pointer(&guid)), + ) + + if ret != 0 { + return "", fmt.Errorf("ConvertInterfaceLuidToGuid failed with code %d: %w", ret, err) + } + + // Format the GUID as a string with curly braces + guidStr := fmt.Sprintf("{%08X-%04X-%04X-%02X%02X-%02X%02X%02X%02X%02X%02X}", + guid.Data1, guid.Data2, guid.Data3, + guid.Data4[0], 
guid.Data4[1], guid.Data4[2], guid.Data4[3], + guid.Data4[4], guid.Data4[5], guid.Data4[6], guid.Data4[7]) + + return guidStr, nil +} diff --git a/go.mod b/go.mod index 8fa1cc1..bf4c165 100644 --- a/go.mod +++ b/go.mod @@ -3,20 +3,31 @@ module github.com/fosrl/olm go 1.25 require ( - github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7 + github.com/Microsoft/go-winio v0.6.2 + github.com/fosrl/newt v0.0.0 + github.com/godbus/dbus/v5 v5.2.0 github.com/gorilla/websocket v1.5.3 - github.com/vishvananda/netlink v1.3.1 - golang.org/x/crypto v0.45.0 - golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 + github.com/miekg/dns v1.1.68 golang.org/x/sys v0.38.0 golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 + gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c software.sslmate.com/src/go-pkcs12 v0.6.0 ) require ( + github.com/google/btree v1.1.3 // indirect + github.com/vishvananda/netlink v1.3.1 // indirect github.com/vishvananda/netns v0.0.5 // indirect + golang.org/x/crypto v0.45.0 // indirect + golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect + golang.org/x/mod v0.30.0 // indirect golang.org/x/net v0.47.0 // indirect + golang.org/x/sync v0.18.0 // indirect + golang.org/x/time v0.12.0 // indirect + golang.org/x/tools v0.39.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect - gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 // indirect + golang.zx2c4.com/wireguard/windows v0.5.3 // indirect ) + +replace github.com/fosrl/newt => ../newt diff --git a/go.sum b/go.sum index 9cdcf9d..f6ca61a 100644 --- a/go.sum +++ b/go.sum @@ -1,34 +1,46 @@ -github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7 h1:6bSU8Efyhx1SR53iSw1Wjk5V8vDfizGAudq/GlE9b+o= -github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7/go.mod h1:Ac0k2FmAMC+hu21rAK+p7EnnEGrqKO/QZuGTVHA/XDM= +github.com/Microsoft/go-winio v0.6.2 
h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= +github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8= +github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA= +github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps= github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= -golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4= -golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc= +golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0= +golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0= +golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= 
+golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= +golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= +golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A= golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ= -gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 h1:H+qymc2ndLKNFR5TcaPmsHGiJnhJMqeofBYSRq4oG3c= -gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56/go.mod h1:i8iCZyAdwRnLZYaIi2NUL1gfNtAveqxkKAe0JfAv9Bs= +golang.zx2c4.com/wireguard/windows v0.5.3 
h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= +golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= +gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI= +gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g= software.sslmate.com/src/go-pkcs12 v0.6.0 h1:f3sQittAeF+pao32Vb+mkli+ZyT+VwKaD014qFGq6oU= software.sslmate.com/src/go-pkcs12 v0.6.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI= diff --git a/httpserver/httpserver.go b/httpserver/httpserver.go deleted file mode 100644 index 4f57cca..0000000 --- a/httpserver/httpserver.go +++ /dev/null @@ -1,217 +0,0 @@ -package httpserver - -import ( - "encoding/json" - "fmt" - "net/http" - "sync" - "time" - - "github.com/fosrl/newt/logger" -) - -// ConnectionRequest defines the structure for an incoming connection request -type ConnectionRequest struct { - ID string `json:"id"` - Secret string `json:"secret"` - Endpoint string `json:"endpoint"` -} - -// PeerStatus represents the status of a peer connection -type PeerStatus struct { - SiteID int `json:"siteId"` - Connected bool `json:"connected"` - RTT time.Duration `json:"rtt"` - LastSeen time.Time `json:"lastSeen"` - Endpoint string `json:"endpoint,omitempty"` - IsRelay bool `json:"isRelay"` -} - -// StatusResponse is returned by the status endpoint -type StatusResponse struct { - Status string `json:"status"` - Connected bool `json:"connected"` - TunnelIP string `json:"tunnelIP,omitempty"` - Version string `json:"version,omitempty"` - PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"` -} - -// HTTPServer represents the HTTP server and its state -type HTTPServer struct { - addr string - server *http.Server - connectionChan chan ConnectionRequest - statusMu sync.RWMutex - peerStatuses map[int]*PeerStatus - connectedAt time.Time - isConnected bool - tunnelIP string - version string -} - -// NewHTTPServer creates a 
new HTTP server -func NewHTTPServer(addr string) *HTTPServer { - s := &HTTPServer{ - addr: addr, - connectionChan: make(chan ConnectionRequest, 1), - peerStatuses: make(map[int]*PeerStatus), - } - - return s -} - -// Start starts the HTTP server -func (s *HTTPServer) Start() error { - mux := http.NewServeMux() - mux.HandleFunc("/connect", s.handleConnect) - mux.HandleFunc("/status", s.handleStatus) - - s.server = &http.Server{ - Addr: s.addr, - Handler: mux, - } - - logger.Info("Starting HTTP server on %s", s.addr) - go func() { - if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.Error("HTTP server error: %v", err) - } - }() - - return nil -} - -// Stop stops the HTTP server -func (s *HTTPServer) Stop() error { - logger.Info("Stopping HTTP server") - return s.server.Close() -} - -// GetConnectionChannel returns the channel for receiving connection requests -func (s *HTTPServer) GetConnectionChannel() <-chan ConnectionRequest { - return s.connectionChan -} - -// UpdatePeerStatus updates the status of a peer including endpoint and relay info -func (s *HTTPServer) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) { - s.statusMu.Lock() - defer s.statusMu.Unlock() - - status, exists := s.peerStatuses[siteID] - if !exists { - status = &PeerStatus{ - SiteID: siteID, - } - s.peerStatuses[siteID] = status - } - - status.Connected = connected - status.RTT = rtt - status.LastSeen = time.Now() - status.Endpoint = endpoint - status.IsRelay = isRelay -} - -// SetConnectionStatus sets the overall connection status -func (s *HTTPServer) SetConnectionStatus(isConnected bool) { - s.statusMu.Lock() - defer s.statusMu.Unlock() - - s.isConnected = isConnected - - if isConnected { - s.connectedAt = time.Now() - } else { - // Clear peer statuses when disconnected - s.peerStatuses = make(map[int]*PeerStatus) - } -} - -// SetTunnelIP sets the tunnel IP address -func (s *HTTPServer) SetTunnelIP(tunnelIP 
string) { - s.statusMu.Lock() - defer s.statusMu.Unlock() - s.tunnelIP = tunnelIP -} - -// SetVersion sets the olm version -func (s *HTTPServer) SetVersion(version string) { - s.statusMu.Lock() - defer s.statusMu.Unlock() - s.version = version -} - -// UpdatePeerRelayStatus updates only the relay status of a peer -func (s *HTTPServer) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) { - s.statusMu.Lock() - defer s.statusMu.Unlock() - - status, exists := s.peerStatuses[siteID] - if !exists { - status = &PeerStatus{ - SiteID: siteID, - } - s.peerStatuses[siteID] = status - } - - status.Endpoint = endpoint - status.IsRelay = isRelay -} - -// handleConnect handles the /connect endpoint -func (s *HTTPServer) handleConnect(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - - var req ConnectionRequest - decoder := json.NewDecoder(r.Body) - if err := decoder.Decode(&req); err != nil { - http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest) - return - } - - // Validate required fields - if req.ID == "" || req.Secret == "" || req.Endpoint == "" { - http.Error(w, "Missing required fields: id, secret, and endpoint must be provided", http.StatusBadRequest) - return - } - - // Send the request to the main goroutine - s.connectionChan <- req - - // Return a success response - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusAccepted) - json.NewEncoder(w).Encode(map[string]string{ - "status": "connection request accepted", - }) -} - -// handleStatus handles the /status endpoint -func (s *HTTPServer) handleStatus(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - - s.statusMu.RLock() - defer s.statusMu.RUnlock() - - resp := StatusResponse{ - Connected: s.isConnected, - TunnelIP: s.tunnelIP, - Version: 
s.version, - PeerStatuses: s.peerStatuses, - } - - if s.isConnected { - resp.Status = "connected" - } else { - resp.Status = "disconnected" - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) -} diff --git a/main.go b/main.go index 339ea2f..f6c6973 100644 --- a/main.go +++ b/main.go @@ -2,56 +2,17 @@ package main import ( "context" - "encoding/json" "fmt" - "net" "os" "os/signal" "runtime" - "strconv" - "strings" "syscall" - "time" "github.com/fosrl/newt/logger" "github.com/fosrl/newt/updates" - "github.com/fosrl/olm/httpserver" - "github.com/fosrl/olm/peermonitor" - "github.com/fosrl/olm/websocket" - - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun" - - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/fosrl/olm/olm" ) -// Helper function to format endpoints correctly -func formatEndpoint(endpoint string) string { - if endpoint == "" { - return "" - } - // Check if it's already a valid host:port that SplitHostPort can parse (e.g., [::1]:8080 or 1.2.3.4:8080) - _, _, err := net.SplitHostPort(endpoint) - if err == nil { - return endpoint // Already valid, no change needed - } - - // If it failed, it might be our malformed "ipv6:port" string. Let's check and fix it. - lastColon := strings.LastIndex(endpoint, ":") - if lastColon > 0 { // Ensure there is a colon and it's not the first character - hostPart := endpoint[:lastColon] - // Check if the host part is a literal IPv6 address - if ip := net.ParseIP(hostPart); ip != nil && ip.To4() == nil { - // It is! Reformat it with brackets. - portPart := endpoint[lastColon+1:] - return fmt.Sprintf("[%s]:%s", hostPart, portPart) - } - } - - // If it's not the specific malformed case, return it as is. 
- return endpoint -} - func main() { // Check if we're running as a Windows service if isWindowsService() { @@ -193,17 +154,30 @@ func main() { } } + // Create a context that will be cancelled on interrupt signals + signalCtx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + // Create a separate context for programmatic shutdown (e.g., via API exit) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Run in console mode - runOlmMain(context.Background()) + runOlmMainWithArgs(ctx, cancel, signalCtx, os.Args[1:]) } -func runOlmMain(ctx context.Context) { - runOlmMainWithArgs(ctx, os.Args[1:]) -} +func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCtx context.Context, args []string) { + // Setup Windows event logging if on Windows + if runtime.GOOS == "windows" { + setupWindowsEventLog() + } else { + // Initialize logger for non-Windows platforms + logger.Init(nil) + } -func runOlmMainWithArgs(ctx context.Context, args []string) { // Load configuration from file, env vars, and CLI args // Priority: CLI args > Env vars > Config file > Defaults + // Use the passed args parameter instead of os.Args[1:] to support Windows service mode config, showVersion, showConfig, err := LoadConfig(args) if err != nil { fmt.Printf("Failed to load configuration: %v\n", err) @@ -216,717 +190,75 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { os.Exit(0) } - // Extract commonly used values from config for convenience - var ( - endpoint = config.Endpoint - id = config.ID - secret = config.Secret - mtu = config.MTU - logLevel = config.LogLevel - interfaceName = config.InterfaceName - enableHTTP = config.EnableHTTP - httpAddr = config.HTTPAddr - pingInterval = config.PingIntervalDuration - pingTimeout = config.PingTimeoutDuration - doHolepunch = config.Holepunch - privateKey wgtypes.Key - connected bool - ) - - stopHolepunch = make(chan struct{}) - stopPing = make(chan 
struct{}) - - // Setup Windows event logging if on Windows - if runtime.GOOS == "windows" { - setupWindowsEventLog() - } else { - // Initialize logger for non-Windows platforms - logger.Init() - } - loggerLevel := parseLogLevel(logLevel) - logger.GetLogger().SetLevel(parseLogLevel(logLevel)) - olmVersion := "version_replaceme" if showVersion { fmt.Println("Olm version " + olmVersion) os.Exit(0) } - logger.Info("Olm version " + olmVersion) + logger.Info("Olm version %s", olmVersion) - if err := updates.CheckForUpdate("fosrl", "olm", olmVersion); err != nil { + config.Version = olmVersion + + if err := SaveConfig(config); err != nil { + logger.Error("Failed to save full olm config: %v", err) + } else { + logger.Debug("Saved full olm config with all options") + } + + if err := updates.CheckForUpdate("fosrl", "olm", config.Version); err != nil { logger.Debug("Failed to check for updates: %v", err) } - // Log startup information - logger.Debug("Olm service starting...") - logger.Debug("Parameters: endpoint='%s', id='%s', secret='%s'", endpoint, id, secret) - logger.Debug("HTTP enabled: %v, HTTP addr: %s", enableHTTP, httpAddr) - - if doHolepunch { - logger.Warn("Hole punching is enabled. 
This is EXPERIMENTAL and may not work in all environments.") + // Create a new olm.Config struct and copy values from the main config + olmConfig := olm.GlobalConfig{ + LogLevel: config.LogLevel, + EnableAPI: config.EnableAPI, + HTTPAddr: config.HTTPAddr, + SocketPath: config.SocketPath, + Version: config.Version, + Agent: "Olm CLI", + OnExit: cancel, // Pass cancel function directly to trigger shutdown + OnTerminated: cancel, } - var httpServer *httpserver.HTTPServer - if enableHTTP { - httpServer = httpserver.NewHTTPServer(httpAddr) - httpServer.SetVersion(olmVersion) - if err := httpServer.Start(); err != nil { - logger.Fatal("Failed to start HTTP server: %v", err) - } - - // Use a goroutine to handle connection requests - go func() { - for req := range httpServer.GetConnectionChannel() { - logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint) - - // Set the connection parameters - id = req.ID - secret = req.Secret - endpoint = req.Endpoint - } - }() + olm.Init(ctx, olmConfig) + if err := olm.StartApi(); err != nil { + logger.Fatal("Failed to start API server: %v", err) } - // // Check if required parameters are missing and provide helpful guidance - // missingParams := []string{} - // if id == "" { - // missingParams = append(missingParams, "id (use -id flag or OLM_ID env var)") - // } - // if secret == "" { - // missingParams = append(missingParams, "secret (use -secret flag or OLM_SECRET env var)") - // } - // if endpoint == "" { - // missingParams = append(missingParams, "endpoint (use -endpoint flag or PANGOLIN_ENDPOINT env var)") - // } - - // if len(missingParams) > 0 { - // logger.Error("Missing required parameters: %v", missingParams) - // logger.Error("Either provide them as command line flags or set as environment variables") - // fmt.Printf("ERROR: Missing required parameters: %v\n", missingParams) - // fmt.Printf("Please provide them as command line flags or set as environment variables\n") - // if !enableHTTP 
{ - // logger.Error("HTTP server is disabled, cannot receive parameters via API") - // fmt.Printf("HTTP server is disabled, cannot receive parameters via API\n") - // return - // } - // } - - // Create a new olm - olm, err := websocket.NewClient( - "olm", - id, // CLI arg takes precedence - secret, // CLI arg takes precedence - endpoint, - pingInterval, - pingTimeout, - ) - if err != nil { - logger.Fatal("Failed to create olm: %v", err) + if config.ID != "" && config.Secret != "" && config.Endpoint != "" { + tunnelConfig := olm.TunnelConfig{ + Endpoint: config.Endpoint, + ID: config.ID, + Secret: config.Secret, + UserToken: config.UserToken, + MTU: config.MTU, + DNS: config.DNS, + UpstreamDNS: config.UpstreamDNS, + InterfaceName: config.InterfaceName, + Holepunch: !config.DisableHolepunch, + TlsClientCert: config.TlsClientCert, + PingIntervalDuration: config.PingIntervalDuration, + PingTimeoutDuration: config.PingTimeoutDuration, + OrgID: config.OrgID, + OverrideDNS: config.OverrideDNS, + DisableRelay: config.DisableRelay, + EnableUAPI: true, + } + go olm.StartTunnel(tunnelConfig) + } else { + logger.Info("Incomplete tunnel configuration, not starting tunnel") } - // wait until we have a client id and secret and endpoint - waitCount := 0 - for id == "" || secret == "" || endpoint == "" { - select { - case <-ctx.Done(): - logger.Info("Context cancelled while waiting for credentials") - return - default: - missing := []string{} - if id == "" { - missing = append(missing, "id") - } - if secret == "" { - missing = append(missing, "secret") - } - if endpoint == "" { - missing = append(missing, "endpoint") - } - waitCount++ - if waitCount%10 == 1 { // Log every 10 seconds instead of every second - logger.Debug("Waiting for missing parameters: %v (waiting %d seconds)", missing, waitCount) - } - time.Sleep(1 * time.Second) - } - } - - privateKey, err = wgtypes.GeneratePrivateKey() - if err != nil { - logger.Fatal("Failed to generate private key: %v", err) - } - - // Create 
TUN device and network stack - var dev *device.Device - var wgData WgData - var holePunchData HolePunchData - var uapiListener net.Listener - var tdev tun.Device - - sourcePort, err := FindAvailableUDPPort(49152, 65535) - if err != nil { - fmt.Printf("Error finding available port: %v\n", err) - os.Exit(1) - } - - olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) { - logger.Debug("Received message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return - } - - if err := json.Unmarshal(jsonData, &holePunchData); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - return - } - - // Create a new stopHolepunch channel for the new set of goroutines - stopHolepunch = make(chan struct{}) - - // Start a single hole punch goroutine for all exit nodes - logger.Info("Starting hole punch for %d exit nodes", len(holePunchData.ExitNodes)) - go keepSendingUDPHolePunchToMultipleExitNodes(holePunchData.ExitNodes, id, sourcePort) - }) - - olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) { - // THIS ENDPOINT IS FOR BACKWARD COMPATIBILITY - logger.Debug("Received message: %v", msg.Data) - - type LegacyHolePunchData struct { - ServerPubKey string `json:"serverPubKey"` - Endpoint string `json:"endpoint"` - } - - var legacyHolePunchData LegacyHolePunchData - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return - } - - if err := json.Unmarshal(jsonData, &legacyHolePunchData); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - return - } - - // Stop any existing hole punch goroutines by closing the current channel - select { - case <-stopHolepunch: - // Channel already closed - default: - close(stopHolepunch) - } - - // Create a new stopHolepunch channel for the new set of goroutines - stopHolepunch = make(chan struct{}) - - // Start hole punching for each exit node 
- logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey) - go keepSendingUDPHolePunch(legacyHolePunchData.Endpoint, id, sourcePort, legacyHolePunchData.ServerPubKey) - }) - - olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) { - logger.Debug("Received message: %v", msg.Data) - - if connected { - logger.Info("Already connected. Ignoring new connection request.") - return - } - - if stopRegister != nil { - stopRegister() - stopRegister = nil - } - - close(stopHolepunch) - - // wait 10 milliseconds to ensure the previous connection is closed - logger.Debug("Waiting 500 milliseconds to ensure previous connection is closed") - time.Sleep(500 * time.Millisecond) - - // if there is an existing tunnel then close it - if dev != nil { - logger.Info("Got new message. Closing existing tunnel!") - dev.Close() - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return - } - - if err := json.Unmarshal(jsonData, &wgData); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - return - } - - tdev, err = func() (tun.Device, error) { - if runtime.GOOS == "darwin" { - interfaceName, err := findUnusedUTUN() - if err != nil { - return nil, err - } - return tun.CreateTUN(interfaceName, mtu) - } - if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" { - return createTUNFromFD(tunFdStr, mtu) - } - return tun.CreateTUN(interfaceName, mtu) - }() - - if err != nil { - logger.Error("Failed to create TUN device: %v", err) - return - } - - if realInterfaceName, err2 := tdev.Name(); err2 == nil { - interfaceName = realInterfaceName - } - - fileUAPI, err := func() (*os.File, error) { - if uapiFdStr := os.Getenv(ENV_WG_UAPI_FD); uapiFdStr != "" { - fd, err := strconv.ParseUint(uapiFdStr, 10, 32) - if err != nil { - return nil, err - } - return os.NewFile(uintptr(fd), ""), nil - } - return uapiOpen(interfaceName) - }() 
- if err != nil { - logger.Error("UAPI listen error: %v", err) - os.Exit(1) - return - } - - dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: ")) - - uapiListener, err = uapiListen(interfaceName, fileUAPI) - if err != nil { - logger.Error("Failed to listen on uapi socket: %v", err) - os.Exit(1) - } - - go func() { - for { - conn, err := uapiListener.Accept() - if err != nil { - return - } - go dev.IpcHandle(conn) - } - }() - logger.Info("UAPI listener started") - - if err = dev.Up(); err != nil { - logger.Error("Failed to bring up WireGuard device: %v", err) - } - if err = ConfigureInterface(interfaceName, wgData); err != nil { - logger.Error("Failed to configure interface: %v", err) - } - if httpServer != nil { - httpServer.SetTunnelIP(wgData.TunnelIP) - } - - peerMonitor = peermonitor.NewPeerMonitor( - func(siteID int, connected bool, rtt time.Duration) { - if httpServer != nil { - // Find the site config to get endpoint information - var endpoint string - var isRelay bool - for _, site := range wgData.Sites { - if site.SiteId == siteID { - endpoint = site.Endpoint - // TODO: We'll need to track relay status separately - // For now, assume not using relay unless we get relay data - isRelay = !doHolepunch - break - } - } - httpServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay) - } - if connected { - logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt) - } else { - logger.Warn("Peer %d is disconnected", siteID) - } - }, - fixKey(privateKey.String()), - olm, - dev, - doHolepunch, - ) - - for i := range wgData.Sites { - site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice - if httpServer != nil { - httpServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false) - } - - // Format the endpoint before configuring the peer. 
- site.Endpoint = formatEndpoint(site.Endpoint) - - if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil { - logger.Error("Failed to configure peer: %v", err) - return - } - if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil { - logger.Error("Failed to add route for peer: %v", err) - return - } - if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil { - logger.Error("Failed to add routes for remote subnets: %v", err) - return - } - - logger.Info("Configured peer %s", site.PublicKey) - } - - peerMonitor.Start() - - connected = true - - logger.Info("WireGuard device created.") - }) - - olm.RegisterHandler("olm/wg/peer/update", func(msg websocket.WSMessage) { - logger.Debug("Received update-peer message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var updateData UpdatePeerData - if err := json.Unmarshal(jsonData, &updateData); err != nil { - logger.Error("Error unmarshaling update data: %v", err) - return - } - - // Convert to SiteConfig - siteConfig := SiteConfig{ - SiteId: updateData.SiteId, - Endpoint: updateData.Endpoint, - PublicKey: updateData.PublicKey, - ServerIP: updateData.ServerIP, - ServerPort: updateData.ServerPort, - RemoteSubnets: updateData.RemoteSubnets, - } - - // Update the peer in WireGuard - if dev != nil { - // Find the existing peer to get old data - var oldRemoteSubnets string - var oldPublicKey string - for _, site := range wgData.Sites { - if site.SiteId == updateData.SiteId { - oldRemoteSubnets = site.RemoteSubnets - oldPublicKey = site.PublicKey - break - } - } - - // If the public key has changed, remove the old peer first - if oldPublicKey != "" && oldPublicKey != updateData.PublicKey { - logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey) - if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil { - 
logger.Error("Failed to remove old peer: %v", err) - return - } - } - - // Format the endpoint before updating the peer. - siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint) - - if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { - logger.Error("Failed to update peer: %v", err) - return - } - - // Remove old remote subnet routes if they changed - if oldRemoteSubnets != siteConfig.RemoteSubnets { - if err := removeRoutesForRemoteSubnets(oldRemoteSubnets); err != nil { - logger.Error("Failed to remove old remote subnet routes: %v", err) - // Continue anyway to add new routes - } - - // Add new remote subnet routes - if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil { - logger.Error("Failed to add new remote subnet routes: %v", err) - return - } - } - - // Update successful - logger.Info("Successfully updated peer for site %d", updateData.SiteId) - for i := range wgData.Sites { - if wgData.Sites[i].SiteId == updateData.SiteId { - wgData.Sites[i] = siteConfig - break - } - } - } else { - logger.Error("WireGuard device not initialized") - } - }) - - // Handler for adding a new peer - olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) { - logger.Debug("Received add-peer message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var addData AddPeerData - if err := json.Unmarshal(jsonData, &addData); err != nil { - logger.Error("Error unmarshaling add data: %v", err) - return - } - - // Convert to SiteConfig - siteConfig := SiteConfig{ - SiteId: addData.SiteId, - Endpoint: addData.Endpoint, - PublicKey: addData.PublicKey, - ServerIP: addData.ServerIP, - ServerPort: addData.ServerPort, - RemoteSubnets: addData.RemoteSubnets, - } - - // Add the peer to WireGuard - if dev != nil { - // Format the endpoint before adding the new peer. 
- siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint) - - if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { - logger.Error("Failed to add peer: %v", err) - return - } - if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil { - logger.Error("Failed to add route for new peer: %v", err) - return - } - if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil { - logger.Error("Failed to add routes for remote subnets: %v", err) - return - } - - // Add successful - logger.Info("Successfully added peer for site %d", addData.SiteId) - - // Update WgData with the new peer - wgData.Sites = append(wgData.Sites, siteConfig) - } else { - logger.Error("WireGuard device not initialized") - } - }) - - // Handler for removing a peer - olm.RegisterHandler("olm/wg/peer/remove", func(msg websocket.WSMessage) { - logger.Debug("Received remove-peer message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var removeData RemovePeerData - if err := json.Unmarshal(jsonData, &removeData); err != nil { - logger.Error("Error unmarshaling remove data: %v", err) - return - } - - // Find the peer to remove - var peerToRemove *SiteConfig - var newSites []SiteConfig - - for _, site := range wgData.Sites { - if site.SiteId == removeData.SiteId { - peerToRemove = &site - } else { - newSites = append(newSites, site) - } - } - - if peerToRemove == nil { - logger.Error("Peer with site ID %d not found", removeData.SiteId) - return - } - - // Remove the peer from WireGuard - if dev != nil { - if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil { - logger.Error("Failed to remove peer: %v", err) - // Send error response if needed - return - } - - // Remove route for the peer - err = removeRouteForServerIP(peerToRemove.ServerIP) - if err != nil { - logger.Error("Failed to remove route for peer: %v", 
err) - return - } - - // Remove routes for remote subnets - if err := removeRoutesForRemoteSubnets(peerToRemove.RemoteSubnets); err != nil { - logger.Error("Failed to remove routes for remote subnets: %v", err) - return - } - - // Remove successful - logger.Info("Successfully removed peer for site %d", removeData.SiteId) - - // Update WgData to remove the peer - wgData.Sites = newSites - } else { - logger.Error("WireGuard device not initialized") - } - }) - - olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) { - logger.Debug("Received relay-peer message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var relayData RelayPeerData - if err := json.Unmarshal(jsonData, &relayData); err != nil { - logger.Error("Error unmarshaling relay data: %v", err) - return - } - - primaryRelay, err := resolveDomain(relayData.Endpoint) - if err != nil { - logger.Warn("Failed to resolve primary relay endpoint: %v", err) - } - - // Update HTTP server to mark this peer as using relay - if httpServer != nil { - httpServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true) - } - - peerMonitor.HandleFailover(relayData.SiteId, primaryRelay) - }) - - olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) { - logger.Info("Received no-sites message - no sites available for connection") - - // if stopRegister != nil { - // stopRegister() - // stopRegister = nil - // } - - // select { - // case <-stopHolepunch: - // // Channel already closed, do nothing - // default: - // close(stopHolepunch) - // } - - logger.Info("No sites available - stopped registration and holepunch processes") - }) - - olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) { - logger.Info("Received terminate message") - olm.Close() - }) - - olm.OnConnect(func() error { - logger.Info("Websocket Connected") - - if httpServer != nil { - 
httpServer.SetConnectionStatus(true) - } - - // CRITICAL: Save our full config AFTER websocket saves its limited config - // This ensures all 13 fields are preserved, not just the 4 that websocket saves - if err := SaveConfig(config); err != nil { - logger.Error("Failed to save full olm config: %v", err) - } else { - logger.Debug("Saved full olm config with all options") - } - - if connected { - logger.Debug("Already connected, skipping registration") - return nil - } - - publicKey := privateKey.PublicKey() - - logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch) - - stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ - "publicKey": publicKey.String(), - "relay": !doHolepunch, - "olmVersion": olmVersion, - }, 1*time.Second) - - go keepSendingPing(olm) - - logger.Info("Sent registration message") - return nil - }) - - olm.OnTokenUpdate(func(token string) { - olmToken = token - }) - - // Connect to the WebSocket server - if err := olm.Connect(); err != nil { - logger.Fatal("Failed to connect to server: %v", err) - } - defer olm.Close() - - // Wait for interrupt signal or context cancellation - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - + // Wait for either signal or programmatic shutdown select { - case <-sigCh: - logger.Info("Received interrupt signal") + case <-signalCtx.Done(): + logger.Info("Shutdown signal received, cleaning up...") case <-ctx.Done(): - logger.Info("Context cancelled") + logger.Info("Shutdown requested via API, cleaning up...") } - select { - case <-stopHolepunch: - // Channel already closed, do nothing - default: - close(stopHolepunch) - } - - if stopRegister != nil { - stopRegister() - stopRegister = nil - } - - select { - case <-stopPing: - // Channel already closed - default: - close(stopPing) - } - - if uapiListener != nil { - uapiListener.Close() - } - if dev != nil { - dev.Close() - } - - 
logger.Info("runOlmMain() exiting") - fmt.Printf("runOlmMain() exiting\n") + // Clean up resources + olm.Close() + logger.Info("Shutdown complete") } diff --git a/namespace.sh b/namespace.sh new file mode 100644 index 0000000..c1c3828 --- /dev/null +++ b/namespace.sh @@ -0,0 +1,126 @@ +#!/bin/bash + +# Configuration +NS_NAME="isolated_ns" # Name of the namespace +VETH_HOST="veth_host" # Interface name on host side +VETH_NS="veth_ns" # Interface name inside namespace +HOST_IP="192.168.15.1" # Gateway IP for the namespace (host side) +NS_IP="192.168.15.2" # IP address for the namespace +SUBNET_CIDR="24" # Subnet mask +DNS_SERVER="8.8.8.8" # DNS to use inside namespace + +# Detect the main physical interface (gateway to internet) +PHY_IFACE=$(ip route get 8.8.8.8 | awk -- '{printf $5}') + +# Helper function to check for root +check_root() { + if [ "$EUID" -ne 0 ]; then + echo "Error: This script must be run as root." + exit 1 + fi +} + +setup_ns() { + echo "Bringing up namespace '$NS_NAME'..." + + # 1. Create the network namespace + if ip netns list | grep -q "$NS_NAME"; then + echo "Namespace $NS_NAME already exists. Run 'down' first." + exit 1 + fi + ip netns add "$NS_NAME" + + # 2. Create veth pair + ip link add "$VETH_HOST" type veth peer name "$VETH_NS" + + # 3. Move peer interface to namespace + ip link set "$VETH_NS" netns "$NS_NAME" + + # 4. Configure Host Side Interface + ip addr add "${HOST_IP}/${SUBNET_CIDR}" dev "$VETH_HOST" + ip link set "$VETH_HOST" up + + # 5. Configure Namespace Side Interface + ip netns exec "$NS_NAME" ip addr add "${NS_IP}/${SUBNET_CIDR}" dev "$VETH_NS" + ip netns exec "$NS_NAME" ip link set "$VETH_NS" up + + # 6. Bring up loopback inside namespace (crucial for many apps) + ip netns exec "$NS_NAME" ip link set lo up + + # 7. Routing: Add default gateway inside namespace pointing to host + ip netns exec "$NS_NAME" ip route add default via "$HOST_IP" + + # 8. 
Enable IP forwarding on host + echo 1 > /proc/sys/net/ipv4/ip_forward + + # 9. NAT/Masquerade: Allow traffic from namespace to go out physical interface + # We verify rule doesn't exist first to avoid duplicates + iptables -t nat -C POSTROUTING -s "${NS_IP}/${SUBNET_CIDR}" -o "$PHY_IFACE" -j MASQUERADE 2>/dev/null || \ + iptables -t nat -A POSTROUTING -s "${NS_IP}/${SUBNET_CIDR}" -o "$PHY_IFACE" -j MASQUERADE + + # Allow forwarding from host veth to WAN and back + iptables -C FORWARD -i "$VETH_HOST" -o "$PHY_IFACE" -j ACCEPT 2>/dev/null || \ + iptables -A FORWARD -i "$VETH_HOST" -o "$PHY_IFACE" -j ACCEPT + + iptables -C FORWARD -i "$PHY_IFACE" -o "$VETH_HOST" -j ACCEPT 2>/dev/null || \ + iptables -A FORWARD -i "$PHY_IFACE" -o "$VETH_HOST" -j ACCEPT + + # 10. DNS Setup + # Netns uses /etc/netns//resolv.conf if it exists + mkdir -p "/etc/netns/$NS_NAME" + echo "nameserver $DNS_SERVER" > "/etc/netns/$NS_NAME/resolv.conf" + + echo "Namespace $NS_NAME is UP." + echo "To enter shell: sudo ip netns exec $NS_NAME bash" +} + +teardown_ns() { + echo "Tearing down namespace '$NS_NAME'..." + + # 1. Remove Namespace (this automatically deletes the veth pair inside it) + # The host side veth usually disappears when the peer is destroyed. + if ip netns list | grep -q "$NS_NAME"; then + ip netns del "$NS_NAME" + else + echo "Namespace $NS_NAME does not exist." + fi + + # 2. Clean up veth host side if it still lingers + if ip link show "$VETH_HOST" > /dev/null 2>&1; then + ip link delete "$VETH_HOST" + fi + + # 3. Remove iptables rules + # We use -D to delete the specific rules we added + iptables -t nat -D POSTROUTING -s "${NS_IP}/${SUBNET_CIDR}" -o "$PHY_IFACE" -j MASQUERADE 2>/dev/null + iptables -D FORWARD -i "$VETH_HOST" -o "$PHY_IFACE" -j ACCEPT 2>/dev/null + iptables -D FORWARD -i "$PHY_IFACE" -o "$VETH_HOST" -j ACCEPT 2>/dev/null + + # 4. Remove DNS config + rm -rf "/etc/netns/$NS_NAME" + + echo "Namespace $NS_NAME is DOWN." 
+} + +test_connectivity() { + echo "Testing connectivity inside $NS_NAME..." + ip netns exec "$NS_NAME" ping -c 3 8.8.8.8 +} + +# Main execution logic +check_root + +case "$1" in + up) + setup_ns + ;; + down) + teardown_ns + ;; + test) + test_connectivity + ;; + *) + echo "Usage: $0 {up|down|test}" + exit 1 +esac \ No newline at end of file diff --git a/olm/olm.go b/olm/olm.go new file mode 100644 index 0000000..1f02d8e --- /dev/null +++ b/olm/olm.go @@ -0,0 +1,1066 @@ +package olm + +import ( + "context" + "encoding/json" + "fmt" + "net" + "os" + "runtime" + "strconv" + "strings" + "time" + + "github.com/fosrl/newt/bind" + "github.com/fosrl/newt/clients/permissions" + "github.com/fosrl/newt/holepunch" + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/network" + "github.com/fosrl/newt/util" + "github.com/fosrl/olm/api" + olmDevice "github.com/fosrl/olm/device" + "github.com/fosrl/olm/dns" + dnsOverride "github.com/fosrl/olm/dns/override" + "github.com/fosrl/olm/peers" + "github.com/fosrl/olm/websocket" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +var ( + privateKey wgtypes.Key + connected bool + dev *device.Device + uapiListener net.Listener + tdev tun.Device + middleDev *olmDevice.MiddleDevice + dnsProxy *dns.DNSProxy + apiServer *api.API + olmClient *websocket.Client + tunnelCancel context.CancelFunc + tunnelRunning bool + sharedBind *bind.SharedBind + holePunchManager *holepunch.Manager + globalConfig GlobalConfig + tunnelConfig TunnelConfig + globalCtx context.Context + stopRegister func() + stopPeerSend func() + updateRegister func(newData interface{}) + stopPing chan struct{} + peerManager *peers.PeerManager +) + +// initTunnelInfo creates the shared UDP socket and holepunch manager. +// This is used during initial tunnel setup and when switching organizations. 
+func initTunnelInfo(clientID string) error { + var err error + privateKey, err = wgtypes.GeneratePrivateKey() + if err != nil { + logger.Error("Failed to generate private key: %v", err) + return err + } + + sourcePort, err := util.FindAvailableUDPPort(49152, 65535) + if err != nil { + return fmt.Errorf("failed to find available UDP port: %w", err) + } + + localAddr := &net.UDPAddr{ + Port: int(sourcePort), + IP: net.IPv4zero, + } + + udpConn, err := net.ListenUDP("udp", localAddr) + if err != nil { + return fmt.Errorf("failed to create UDP socket: %w", err) + } + + sharedBind, err = bind.New(udpConn) + if err != nil { + udpConn.Close() + return fmt.Errorf("failed to create shared bind: %w", err) + } + + // Add a reference for the hole punch senders (creator already has one reference for WireGuard) + sharedBind.AddRef() + + logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount()) + + // Create the holepunch manager + holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String()) + + return nil +} + +func Init(ctx context.Context, config GlobalConfig) { + globalConfig = config + globalCtx = ctx + + logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel)) + + logger.Debug("Checking permissions for native interface") + err := permissions.CheckNativeInterfacePermissions() + if err != nil { + logger.Fatal("Insufficient permissions to create native TUN interface: %v", err) + return + } + + if config.HTTPAddr != "" { + apiServer = api.NewAPI(config.HTTPAddr) + } else if config.SocketPath != "" { + apiServer = api.NewAPISocket(config.SocketPath) + } + + apiServer.SetVersion(config.Version) + apiServer.SetAgent(config.Agent) + + // Set up API handlers + apiServer.SetHandlers( + // onConnect + func(req api.ConnectionRequest) error { + logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint) + + // Stop any existing tunnel before starting a new 
one + if olmClient != nil { + logger.Info("Stopping existing tunnel before starting new connection") + StopTunnel() + } + + tunnelConfig := TunnelConfig{ + Endpoint: req.Endpoint, + ID: req.ID, + Secret: req.Secret, + UserToken: req.UserToken, + MTU: req.MTU, + DNS: req.DNS, + UpstreamDNS: req.UpstreamDNS, + InterfaceName: req.InterfaceName, + Holepunch: req.Holepunch, + TlsClientCert: req.TlsClientCert, + OrgID: req.OrgID, + } + + var err error + // Parse ping interval + if req.PingInterval != "" { + tunnelConfig.PingIntervalDuration, err = time.ParseDuration(req.PingInterval) + if err != nil { + logger.Warn("Invalid PING_INTERVAL value: %s, using default 3 seconds", req.PingInterval) + tunnelConfig.PingIntervalDuration = 3 * time.Second + } + } else { + tunnelConfig.PingIntervalDuration = 3 * time.Second + } + // Parse ping timeout + if req.PingTimeout != "" { + tunnelConfig.PingTimeoutDuration, err = time.ParseDuration(req.PingTimeout) + if err != nil { + logger.Warn("Invalid PING_TIMEOUT value: %s, using default 5 seconds", req.PingTimeout) + tunnelConfig.PingTimeoutDuration = 5 * time.Second + } + } else { + tunnelConfig.PingTimeoutDuration = 5 * time.Second + } + if req.MTU == 0 { + tunnelConfig.MTU = 1420 + } + if req.DNS == "" { + tunnelConfig.DNS = "9.9.9.9" + } + // DNSProxyIP has no default - it must be provided if DNS proxy is desired + // UpstreamDNS defaults to 8.8.8.8 if not provided + if len(req.UpstreamDNS) == 0 { + tunnelConfig.UpstreamDNS = []string{"8.8.8.8:53"} + } + if req.InterfaceName == "" { + tunnelConfig.InterfaceName = "olm" + } + + // Start the tunnel process with the new credentials + if tunnelConfig.ID != "" && tunnelConfig.Secret != "" && tunnelConfig.Endpoint != "" { + logger.Info("Starting tunnel with new credentials") + go StartTunnel(tunnelConfig) + } + + return nil + }, + // onSwitchOrg + func(req api.SwitchOrgRequest) error { + logger.Info("Received switch organization request via HTTP: orgID=%s", req.OrgID) + return 
SwitchOrg(req.OrgID) + }, + // onDisconnect + func() error { + logger.Info("Processing disconnect request via API") + return StopTunnel() + }, + // onExit + func() error { + logger.Info("Processing shutdown request via API") + Close() + if globalConfig.OnExit != nil { + globalConfig.OnExit() + } + return nil + }, + ) +} + +func StartTunnel(config TunnelConfig) { + if tunnelRunning { + logger.Info("Tunnel already running") + return + } + + tunnelRunning = true // Also set it here in case it is called externally + tunnelConfig = config + + // Reset terminated status when tunnel starts + apiServer.SetTerminated(false) + + // debug print out the whole config + logger.Debug("Starting tunnel with config: %+v", config) + + // Create a cancellable context for this tunnel process + tunnelCtx, cancel := context.WithCancel(globalCtx) + tunnelCancel = cancel + defer func() { + tunnelCancel = nil + }() + + // Recreate channels for this tunnel session + stopPing = make(chan struct{}) + + var ( + interfaceName = config.InterfaceName + id = config.ID + secret = config.Secret + userToken = config.UserToken + ) + + apiServer.SetOrgID(config.OrgID) + + // Create a new olm client using the provided credentials + olm, err := websocket.NewClient( + id, // Use provided ID + secret, // Use provided secret + userToken, // Use provided user token OPTIONAL + config.OrgID, + config.Endpoint, // Use provided endpoint + config.PingIntervalDuration, + config.PingTimeoutDuration, + ) + if err != nil { + logger.Error("Failed to create olm: %v", err) + return + } + + // Store the client reference globally + olmClient = olm + + // Create shared UDP socket and holepunch manager + if err := initTunnelInfo(id); err != nil { + logger.Error("%v", err) + return + } + + olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) { + logger.Debug("Received message: %v", msg.Data) + + var wgData WgData + + if connected { + logger.Info("Already connected. 
Ignoring new connection request.") + return + } + + if stopRegister != nil { + stopRegister() + stopRegister = nil + } + + if updateRegister != nil { + updateRegister = nil + } + + // if there is an existing tunnel then close it + if dev != nil { + logger.Info("Got new message. Closing existing tunnel!") + dev.Close() + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &wgData); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + return + } + + tdev, err = func() (tun.Device, error) { + if config.FileDescriptorTun != 0 { + return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU) + } + var ifName = interfaceName + if runtime.GOOS == "darwin" { // this is if we dont pass a fd + ifName, err = network.FindUnusedUTUN() + if err != nil { + return nil, err + } + } + return tun.CreateTUN(ifName, config.MTU) + }() + + if err != nil { + logger.Error("Failed to create TUN device: %v", err) + return + } + + if config.FileDescriptorTun == 0 { + if realInterfaceName, err2 := tdev.Name(); err2 == nil { + interfaceName = realInterfaceName + } + } + + // Wrap TUN device with packet filter for DNS proxy + middleDev = olmDevice.NewMiddleDevice(tdev) + + wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ") + // Use filtered device instead of raw TUN device + dev = device.NewDevice(middleDev, sharedBind, (*device.Logger)(wgLogger)) + + if config.EnableUAPI { + fileUAPI, err := func() (*os.File, error) { + if config.FileDescriptorUAPI != 0 { + fd, err := strconv.ParseUint(fmt.Sprintf("%d", config.FileDescriptorUAPI), 10, 32) + if err != nil { + return nil, fmt.Errorf("invalid UAPI file descriptor: %v", err) + } + return os.NewFile(uintptr(fd), ""), nil + } + return olmDevice.UapiOpen(interfaceName) + }() + if err != nil { + logger.Error("UAPI listen error: %v", err) + os.Exit(1) + return + } + + uapiListener, err = 
olmDevice.UapiListen(interfaceName, fileUAPI) + if err != nil { + logger.Error("Failed to listen on uapi socket: %v", err) + os.Exit(1) + } + + go func() { + for { + conn, err := uapiListener.Accept() + if err != nil { + + return + } + go dev.IpcHandle(conn) + } + }() + logger.Info("UAPI listener started") + } + + if err = dev.Up(); err != nil { + logger.Error("Failed to bring up WireGuard device: %v", err) + } + + // Create and start DNS proxy + dnsProxy, err = dns.NewDNSProxy(tdev, middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS) + if err != nil { + logger.Error("Failed to create DNS proxy: %v", err) + } + + if err = network.ConfigureInterface(interfaceName, wgData.TunnelIP, config.MTU); err != nil { + logger.Error("Failed to configure interface: %v", err) + } + + if network.AddRoutes([]string{wgData.UtilitySubnet}, interfaceName); err != nil { // also route the utility subnet + logger.Error("Failed to add route for utility subnet: %v", err) + } + + // TODO: seperate adding the callback to this so we can init it above with the interface + interfaceIP := wgData.TunnelIP + if strings.Contains(interfaceIP, "/") { + interfaceIP = strings.Split(interfaceIP, "/")[0] + } + + // Create peer manager with integrated peer monitoring + peerManager = peers.NewPeerManager(peers.PeerManagerConfig{ + Device: dev, + DNSProxy: dnsProxy, + InterfaceName: interfaceName, + PrivateKey: privateKey, + MiddleDev: middleDev, + LocalIP: interfaceIP, + SharedBind: sharedBind, + WSClient: olm, + APIServer: apiServer, + }) + + for i := range wgData.Sites { + site := wgData.Sites[i] + var siteEndpoint string + // here we are going to take the relay endpoint if it exists which means we requested a relay for this peer + if site.RelayEndpoint != "" { + siteEndpoint = site.RelayEndpoint + } else { + siteEndpoint = site.Endpoint + } + + apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false) + + if err := peerManager.AddPeer(site); err != nil { + 
logger.Error("Failed to add peer: %v", err) + return + } + + logger.Info("Configured peer %s", site.PublicKey) + } + + peerManager.Start() + + if err := dnsProxy.Start(); err != nil { // start DNS proxy first so there is no downtime + logger.Error("Failed to start DNS proxy: %v", err) + } + + if config.OverrideDNS { + // Set up DNS override to use our DNS proxy + if err := dnsOverride.SetupDNSOverride(interfaceName, dnsProxy); err != nil { + logger.Error("Failed to setup DNS override: %v", err) + return + } + } + + apiServer.SetRegistered(true) + + connected = true + + // Invoke onConnected callback if configured + if globalConfig.OnConnected != nil { + go globalConfig.OnConnected() + } + + logger.Info("WireGuard device created.") + }) + + olm.RegisterHandler("olm/wg/peer/update", func(msg websocket.WSMessage) { + logger.Debug("Received update-peer message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var updateData peers.SiteConfig + if err := json.Unmarshal(jsonData, &updateData); err != nil { + logger.Error("Error unmarshaling update data: %v", err) + return + } + + // Get existing peer from PeerManager + existingPeer, exists := peerManager.GetPeer(updateData.SiteId) + if !exists { + logger.Error("Peer with site ID %d not found", updateData.SiteId) + return + } + + // Create updated site config by merging with existing data + siteConfig := existingPeer + + if updateData.Endpoint != "" { + siteConfig.Endpoint = updateData.Endpoint + } + if updateData.RelayEndpoint != "" { + siteConfig.RelayEndpoint = updateData.RelayEndpoint + } + if updateData.PublicKey != "" { + siteConfig.PublicKey = updateData.PublicKey + } + if updateData.ServerIP != "" { + siteConfig.ServerIP = updateData.ServerIP + } + if updateData.ServerPort != 0 { + siteConfig.ServerPort = updateData.ServerPort + } + if updateData.RemoteSubnets != nil { + siteConfig.RemoteSubnets = updateData.RemoteSubnets 
+ } + + if err := peerManager.UpdatePeer(siteConfig); err != nil { + logger.Error("Failed to update peer: %v", err) + return + } + + // Update successful + logger.Info("Successfully updated peer for site %d", updateData.SiteId) + }) + + // Handler for adding a new peer + olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) { + logger.Debug("Received add-peer message: %v", msg.Data) + + if stopPeerSend != nil { + stopPeerSend() + stopPeerSend = nil + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var siteConfig peers.SiteConfig + if err := json.Unmarshal(jsonData, &siteConfig); err != nil { + logger.Error("Error unmarshaling add data: %v", err) + return + } + + holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it + + if err := peerManager.AddPeer(siteConfig); err != nil { + logger.Error("Failed to add peer: %v", err) + return + } + + // Add successful + logger.Info("Successfully added peer for site %d", siteConfig.SiteId) + }) + + // Handler for removing a peer + olm.RegisterHandler("olm/wg/peer/remove", func(msg websocket.WSMessage) { + logger.Debug("Received remove-peer message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var removeData peers.PeerRemove + if err := json.Unmarshal(jsonData, &removeData); err != nil { + logger.Error("Error unmarshaling remove data: %v", err) + return + } + + if err := peerManager.RemovePeer(removeData.SiteId); err != nil { + logger.Error("Failed to remove peer: %v", err) + return + } + + // Remove successful + logger.Info("Successfully removed peer for site %d", removeData.SiteId) + }) + + // Handler for adding remote subnets to a peer + olm.RegisterHandler("olm/wg/peer/data/add", func(msg websocket.WSMessage) { + 
logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var addSubnetsData peers.PeerAdd + if err := json.Unmarshal(jsonData, &addSubnetsData); err != nil { + logger.Error("Error unmarshaling add-remote-subnets data: %v", err) + return + } + + if _, exists := peerManager.GetPeer(addSubnetsData.SiteId); !exists { + logger.Debug("Peer %d not found for removing remote subnets and aliases", addSubnetsData.SiteId) + return + } + + // Add new subnets + for _, subnet := range addSubnetsData.RemoteSubnets { + if err := peerManager.AddRemoteSubnet(addSubnetsData.SiteId, subnet); err != nil { + logger.Error("Failed to add allowed IP %s: %v", subnet, err) + } + } + + // Add new aliases + for _, alias := range addSubnetsData.Aliases { + if err := peerManager.AddAlias(addSubnetsData.SiteId, alias); err != nil { + logger.Error("Failed to add alias %s: %v", alias.Alias, err) + } + } + }) + + // Handler for removing remote subnets from a peer + olm.RegisterHandler("olm/wg/peer/data/remove", func(msg websocket.WSMessage) { + logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var removeSubnetsData peers.RemovePeerData + if err := json.Unmarshal(jsonData, &removeSubnetsData); err != nil { + logger.Error("Error unmarshaling remove-remote-subnets data: %v", err) + return + } + + if _, exists := peerManager.GetPeer(removeSubnetsData.SiteId); !exists { + logger.Debug("Peer %d not found for removing remote subnets and aliases", removeSubnetsData.SiteId) + return + } + + // Remove subnets + for _, subnet := range removeSubnetsData.RemoteSubnets { + if err := peerManager.RemoveRemoteSubnet(removeSubnetsData.SiteId, subnet); err != nil { + logger.Error("Failed to remove allowed IP %s: %v", subnet, 
err) + } + } + + // Remove aliases + for _, alias := range removeSubnetsData.Aliases { + if err := peerManager.RemoveAlias(removeSubnetsData.SiteId, alias.Alias); err != nil { + logger.Error("Failed to remove alias %s: %v", alias.Alias, err) + } + } + }) + + // Handler for updating remote subnets of a peer (remove old, add new in one operation) + olm.RegisterHandler("olm/wg/peer/data/update", func(msg websocket.WSMessage) { + logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var updateSubnetsData peers.UpdatePeerData + if err := json.Unmarshal(jsonData, &updateSubnetsData); err != nil { + logger.Error("Error unmarshaling update-remote-subnets data: %v", err) + return + } + + if _, exists := peerManager.GetPeer(updateSubnetsData.SiteId); !exists { + logger.Debug("Peer %d not found for removing remote subnets and aliases", updateSubnetsData.SiteId) + return + } + + // Add new subnets BEFORE removing old ones to preserve shared subnets + // This ensures that if an old and new subnet are the same on different peers, + // the route won't be temporarily removed + for _, subnet := range updateSubnetsData.NewRemoteSubnets { + if err := peerManager.AddRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil { + logger.Error("Failed to add allowed IP %s: %v", subnet, err) + } + } + + // Remove old subnets after new ones are added + for _, subnet := range updateSubnetsData.OldRemoteSubnets { + if err := peerManager.RemoveRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil { + logger.Error("Failed to remove allowed IP %s: %v", subnet, err) + } + } + + // Add new aliases BEFORE removing old ones to preserve shared IP addresses + // This ensures that if an old and new alias share the same IP, the IP won't be + // temporarily removed from the allowed IPs list + for _, alias := range updateSubnetsData.NewAliases { + if 
err := peerManager.AddAlias(updateSubnetsData.SiteId, alias); err != nil { + logger.Error("Failed to add alias %s: %v", alias.Alias, err) + } + } + + // Remove old aliases after new ones are added + for _, alias := range updateSubnetsData.OldAliases { + if err := peerManager.RemoveAlias(updateSubnetsData.SiteId, alias.Alias); err != nil { + logger.Error("Failed to remove alias %s: %v", alias.Alias, err) + } + } + + logger.Info("Successfully updated remote subnets and aliases for peer %d", updateSubnetsData.SiteId) + }) + + olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) { + logger.Debug("Received relay-peer message: %v", msg.Data) + + // Check if peerManager is still valid (may be nil during shutdown) + if peerManager == nil { + logger.Debug("Ignoring relay message: peerManager is nil (shutdown in progress)") + return + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var relayData peers.RelayPeerData + if err := json.Unmarshal(jsonData, &relayData); err != nil { + logger.Error("Error unmarshaling relay data: %v", err) + return + } + + primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint) + if err != nil { + logger.Warn("Failed to resolve primary relay endpoint: %v", err) + } + + // Update HTTP server to mark this peer as using relay + apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true) + + peerManager.RelayPeer(relayData.SiteId, primaryRelay) + }) + + olm.RegisterHandler("olm/wg/peer/unrelay", func(msg websocket.WSMessage) { + logger.Debug("Received unrelay-peer message: %v", msg.Data) + + // Check if peerManager is still valid (may be nil during shutdown) + if peerManager == nil { + logger.Debug("Ignoring unrelay message: peerManager is nil (shutdown in progress)") + return + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var relayData 
peers.UnRelayPeerData + if err := json.Unmarshal(jsonData, &relayData); err != nil { + logger.Error("Error unmarshaling relay data: %v", err) + return + } + + primaryRelay, err := util.ResolveDomain(relayData.Endpoint) + if err != nil { + logger.Warn("Failed to resolve primary relay endpoint: %v", err) + } + + // Update HTTP server to mark this peer as using relay + apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, false) + + peerManager.UnRelayPeer(relayData.SiteId, primaryRelay) + }) + + // Handler for peer handshake - adds exit node to holepunch rotation and notifies server + olm.RegisterHandler("olm/wg/peer/holepunch/site/add", func(msg websocket.WSMessage) { + logger.Debug("Received peer-handshake message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling handshake data: %v", err) + return + } + + var handshakeData struct { + SiteId int `json:"siteId"` + ExitNode struct { + PublicKey string `json:"publicKey"` + Endpoint string `json:"endpoint"` + } `json:"exitNode"` + } + + if err := json.Unmarshal(jsonData, &handshakeData); err != nil { + logger.Error("Error unmarshaling handshake data: %v", err) + return + } + + exitNode := holepunch.ExitNode{ + Endpoint: handshakeData.ExitNode.Endpoint, + PublicKey: handshakeData.ExitNode.PublicKey, + } + + added := holePunchManager.AddExitNode(exitNode) + if added { + logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint) + } else { + logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint) + } + + holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt + holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud + + // Send handshake acknowledgment back to server with retry + stopPeerSend, _ = olm.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ + "siteId": handshakeData.SiteId, + }, 1*time.Second) + 
+ logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) + }) + + olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) { + logger.Info("Received terminate message") + apiServer.SetTerminated(true) + apiServer.SetConnectionStatus(false) + apiServer.SetRegistered(false) + apiServer.ClearPeerStatuses() + network.ClearNetworkSettings() + Close() + + if globalConfig.OnTerminated != nil { + go globalConfig.OnTerminated() + } + }) + + olm.RegisterHandler("pong", func(msg websocket.WSMessage) { + logger.Debug("Received pong message") + }) + + olm.OnConnect(func() error { + logger.Info("Websocket Connected") + + apiServer.SetConnectionStatus(true) + + if connected { + logger.Debug("Already connected, skipping registration") + return nil + } + + publicKey := privateKey.PublicKey() + + // delay for 500ms to allow for time for the hp to get processed + time.Sleep(500 * time.Millisecond) + + if stopRegister == nil { + logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch) + stopRegister, updateRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ + "publicKey": publicKey.String(), + "relay": !config.Holepunch, + "olmVersion": globalConfig.Version, + "olmAgent": globalConfig.Agent, + "orgId": config.OrgID, + "userToken": userToken, + }, 1*time.Second) + + // Invoke onRegistered callback if configured + if globalConfig.OnRegistered != nil { + go globalConfig.OnRegistered() + } + } + + go keepSendingPing(olm) + + return nil + }) + + olm.OnTokenUpdate(func(token string, exitNodes []websocket.ExitNode) { + holePunchManager.SetToken(token) + + logger.Debug("Got exit nodes for hole punching: %v", exitNodes) + + // Convert websocket.ExitNode to holepunch.ExitNode + hpExitNodes := make([]holepunch.ExitNode, len(exitNodes)) + for i, node := range exitNodes { + hpExitNodes[i] = holepunch.ExitNode{ + Endpoint: node.Endpoint, 
+ PublicKey: node.PublicKey, + } + } + + logger.Debug("Updated hole punch exit nodes: %v", hpExitNodes) + + // Start hole punching using the manager + logger.Info("Starting hole punch for %d exit nodes", len(exitNodes)) + if err := holePunchManager.StartMultipleExitNodes(hpExitNodes); err != nil { + logger.Warn("Failed to start hole punch: %v", err) + } + }) + + olm.OnAuthError(func(statusCode int, message string) { + logger.Error("Authentication error (status %d): %s. Terminating tunnel.", statusCode, message) + apiServer.SetTerminated(true) + apiServer.SetConnectionStatus(false) + apiServer.SetRegistered(false) + apiServer.ClearPeerStatuses() + network.ClearNetworkSettings() + + Close() + + if globalConfig.OnAuthError != nil { + go globalConfig.OnAuthError(statusCode, message) + } + + if globalConfig.OnTerminated != nil { + go globalConfig.OnTerminated() + } + }) + + // Connect to the WebSocket server + if err := olm.Connect(); err != nil { + logger.Error("Failed to connect to server: %v", err) + return + } + defer olm.Close() + + // Wait for context cancellation + <-tunnelCtx.Done() + logger.Info("Tunnel process context cancelled, cleaning up") +} + +func Close() { + // Restore original DNS configuration + // we do this first to avoid any DNS issues if something else gets stuck + if err := dnsOverride.RestoreDNSOverride(); err != nil { + logger.Error("Failed to restore DNS: %v", err) + } + + // Stop hole punch manager + if holePunchManager != nil { + holePunchManager.Stop() + holePunchManager = nil + } + + if stopPing != nil { + select { + case <-stopPing: + // Channel already closed + default: + close(stopPing) + } + } + + if stopRegister != nil { + stopRegister() + stopRegister = nil + } + + if updateRegister != nil { + updateRegister = nil + } + + if peerManager != nil { + peerManager.Close() // Close() also calls Stop() internally + peerManager = nil + } + + if uapiListener != nil { + uapiListener.Close() + uapiListener = nil + } + + // Stop DNS proxy first 
- it uses the middleDev for packet filtering + logger.Debug("Stopping DNS proxy") + if dnsProxy != nil { + dnsProxy.Stop() + dnsProxy = nil + } + + // Close MiddleDevice first - this closes the TUN and signals the closed channel + // This unblocks the pump goroutine and allows WireGuard's TUN reader to exit + logger.Debug("Closing MiddleDevice") + if middleDev != nil { + middleDev.Close() + middleDev = nil + } + // Note: tdev is closed by middleDev.Close() since middleDev wraps it + tdev = nil + + // Now close WireGuard device - its TUN reader should have exited by now + logger.Debug("Closing WireGuard device") + if dev != nil { + dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference + dev = nil + } + + // Release the hole punch reference to the shared bind + if sharedBind != nil { + // Release hole punch reference (WireGuard already released its reference via dev.Close()) + logger.Debug("Releasing shared bind (refcount before release: %d)", sharedBind.GetRefCount()) + sharedBind.Release() + sharedBind = nil + logger.Info("Released shared UDP bind") + } + + logger.Info("Olm service stopped") +} + +// StopTunnel stops just the tunnel process and websocket connection +// without shutting down the entire application +func StopTunnel() error { + logger.Info("Stopping tunnel process") + + // Cancel the tunnel context if it exists + if tunnelCancel != nil { + tunnelCancel() + // Give it a moment to clean up + time.Sleep(200 * time.Millisecond) + } + + // Close the websocket connection + if olmClient != nil { + olmClient.Close() + olmClient = nil + } + + Close() + + // Reset the connected state + connected = false + tunnelRunning = false + + // Update API server status + apiServer.SetConnectionStatus(false) + apiServer.SetRegistered(false) + + network.ClearNetworkSettings() + apiServer.ClearPeerStatuses() + + logger.Info("Tunnel process stopped") + + return nil +} + +func StopApi() error { + if apiServer != nil { + err := apiServer.Stop() + 
if err != nil { + return fmt.Errorf("failed to stop API server: %w", err) + } + } + return nil +} + +func StartApi() error { + if apiServer != nil { + err := apiServer.Start() + if err != nil { + return fmt.Errorf("failed to start API server: %w", err) + } + } + return nil +} + +func GetStatus() api.StatusResponse { + return apiServer.GetStatus() +} + +func SwitchOrg(orgID string) error { + logger.Info("Processing org switch request to orgId: %s", orgID) + // stop the tunnel + if err := StopTunnel(); err != nil { + return fmt.Errorf("failed to stop existing tunnel: %w", err) + } + + // Update the org ID in the API server and global config + apiServer.SetOrgID(orgID) + + tunnelConfig.OrgID = orgID + + // Restart the tunnel with the same config but new org ID + go StartTunnel(tunnelConfig) + + return nil +} diff --git a/olm/types.go b/olm/types.go new file mode 100644 index 0000000..993bb56 --- /dev/null +++ b/olm/types.go @@ -0,0 +1,66 @@ +package olm + +import ( + "time" + + "github.com/fosrl/olm/peers" +) + +type WgData struct { + Sites []peers.SiteConfig `json:"sites"` + TunnelIP string `json:"tunnelIP"` + UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses +} + +type GlobalConfig struct { + // Logging + LogLevel string + + // HTTP server + EnableAPI bool + HTTPAddr string + SocketPath string + Version string + Agent string + + // Callbacks + OnRegistered func() + OnConnected func() + OnTerminated func() + OnAuthError func(statusCode int, message string) // Called when auth fails (401/403) + OnExit func() // Called when exit is requested via API +} + +type TunnelConfig struct { + // Connection settings + Endpoint string + ID string + Secret string + UserToken string + + // Network settings + MTU int + DNS string + UpstreamDNS []string + InterfaceName string + + // Advanced + Holepunch bool + TlsClientCert string + + // Parsed values (not in JSON) + PingIntervalDuration time.Duration + PingTimeoutDuration 
time.Duration + + OrgID string + // DoNotCreateNewClient bool + + FileDescriptorTun uint32 + FileDescriptorUAPI uint32 + + EnableUAPI bool + + OverrideDNS bool + + DisableRelay bool +} diff --git a/olm/util.go b/olm/util.go new file mode 100644 index 0000000..9da1f00 --- /dev/null +++ b/olm/util.go @@ -0,0 +1,98 @@ +package olm + +import ( + "fmt" + "net" + "strings" + "time" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/network" + "github.com/fosrl/olm/websocket" +) + +// Helper function to format endpoints correctly +func formatEndpoint(endpoint string) string { + if endpoint == "" { + return "" + } + // Check if it's already a valid host:port that SplitHostPort can parse (e.g., [::1]:8080 or 1.2.3.4:8080) + _, _, err := net.SplitHostPort(endpoint) + if err == nil { + return endpoint // Already valid, no change needed + } + + // If it failed, it might be our malformed "ipv6:port" string. Let's check and fix it. + lastColon := strings.LastIndex(endpoint, ":") + if lastColon > 0 { // Ensure there is a colon and it's not the first character + hostPart := endpoint[:lastColon] + // Check if the host part is a literal IPv6 address + if ip := net.ParseIP(hostPart); ip != nil && ip.To4() == nil { + // It is! Reformat it with brackets. + portPart := endpoint[lastColon+1:] + return fmt.Sprintf("[%s]:%s", hostPart, portPart) + } + } + + // If it's not the specific malformed case, return it as is. 
+ return endpoint +} + +func sendPing(olm *websocket.Client) error { + err := olm.SendMessage("olm/ping", map[string]interface{}{ + "timestamp": time.Now().Unix(), + "userToken": olm.GetConfig().UserToken, + }) + if err != nil { + logger.Error("Failed to send ping message: %v", err) + return err + } + logger.Debug("Sent ping message") + return nil +} + +func keepSendingPing(olm *websocket.Client) { + // Send ping immediately on startup + if err := sendPing(olm); err != nil { + logger.Error("Failed to send initial ping: %v", err) + } else { + logger.Info("Sent initial ping message") + } + + // Set up ticker for one minute intervals + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-stopPing: + logger.Info("Stopping ping messages") + return + case <-ticker.C: + if err := sendPing(olm); err != nil { + logger.Error("Failed to send periodic ping: %v", err) + } + } + } +} + +func GetNetworkSettingsJSON() (string, error) { + return network.GetJSON() +} + +func GetNetworkSettingsIncrementor() int { + return network.GetIncrementor() +} + +// stringSlicesEqual compares two string slices for equality +func stringSlicesEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/peermonitor/peermonitor.go b/peermonitor/peermonitor.go deleted file mode 100644 index afa8248..0000000 --- a/peermonitor/peermonitor.go +++ /dev/null @@ -1,331 +0,0 @@ -package peermonitor - -import ( - "context" - "fmt" - "strings" - "sync" - "time" - - "github.com/fosrl/newt/logger" - "github.com/fosrl/olm/websocket" - "github.com/fosrl/olm/wgtester" - "golang.zx2c4.com/wireguard/device" -) - -// PeerMonitorCallback is the function type for connection status change callbacks -type PeerMonitorCallback func(siteID int, connected bool, rtt time.Duration) - -// WireGuardConfig holds the WireGuard configuration for a peer -type WireGuardConfig struct { - 
SiteID int - PublicKey string - ServerIP string - Endpoint string - PrimaryRelay string // The primary relay endpoint -} - -// PeerMonitor handles monitoring the connection status to multiple WireGuard peers -type PeerMonitor struct { - monitors map[int]*wgtester.Client - configs map[int]*WireGuardConfig - callback PeerMonitorCallback - mutex sync.Mutex - running bool - interval time.Duration - timeout time.Duration - maxAttempts int - privateKey string - wsClient *websocket.Client - device *device.Device - handleRelaySwitch bool // Whether to handle relay switching -} - -// NewPeerMonitor creates a new peer monitor with the given callback -func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool) *PeerMonitor { - return &PeerMonitor{ - monitors: make(map[int]*wgtester.Client), - configs: make(map[int]*WireGuardConfig), - callback: callback, - interval: 1 * time.Second, // Default check interval - timeout: 2500 * time.Millisecond, - maxAttempts: 8, - privateKey: privateKey, - wsClient: wsClient, - device: device, - handleRelaySwitch: handleRelaySwitch, - } -} - -// SetInterval changes how frequently peers are checked -func (pm *PeerMonitor) SetInterval(interval time.Duration) { - pm.mutex.Lock() - defer pm.mutex.Unlock() - - pm.interval = interval - - // Update interval for all existing monitors - for _, client := range pm.monitors { - client.SetPacketInterval(interval) - } -} - -// SetTimeout changes the timeout for waiting for responses -func (pm *PeerMonitor) SetTimeout(timeout time.Duration) { - pm.mutex.Lock() - defer pm.mutex.Unlock() - - pm.timeout = timeout - - // Update timeout for all existing monitors - for _, client := range pm.monitors { - client.SetTimeout(timeout) - } -} - -// SetMaxAttempts changes the maximum number of attempts for TestConnection -func (pm *PeerMonitor) SetMaxAttempts(attempts int) { - pm.mutex.Lock() - defer pm.mutex.Unlock() - - pm.maxAttempts = 
attempts - - // Update max attempts for all existing monitors - for _, client := range pm.monitors { - client.SetMaxAttempts(attempts) - } -} - -// AddPeer adds a new peer to monitor -func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, wgConfig *WireGuardConfig) error { - pm.mutex.Lock() - defer pm.mutex.Unlock() - - // Check if we're already monitoring this peer - if _, exists := pm.monitors[siteID]; exists { - // Update the endpoint instead of creating a new monitor - pm.removePeerUnlocked(siteID) - } - - client, err := wgtester.NewClient(endpoint) - if err != nil { - return err - } - - // Configure the client with our settings - client.SetPacketInterval(pm.interval) - client.SetTimeout(pm.timeout) - client.SetMaxAttempts(pm.maxAttempts) - - // Store the client and config - pm.monitors[siteID] = client - pm.configs[siteID] = wgConfig - - // If monitor is already running, start monitoring this peer - if pm.running { - siteIDCopy := siteID // Create a copy for the closure - err = client.StartMonitor(func(status wgtester.ConnectionStatus) { - pm.handleConnectionStatusChange(siteIDCopy, status) - }) - } - - return err -} - -// removePeerUnlocked stops monitoring a peer and removes it from the monitor -// This function assumes the mutex is already held by the caller -func (pm *PeerMonitor) removePeerUnlocked(siteID int) { - client, exists := pm.monitors[siteID] - if !exists { - return - } - - client.StopMonitor() - client.Close() - delete(pm.monitors, siteID) - delete(pm.configs, siteID) -} - -// RemovePeer stops monitoring a peer and removes it from the monitor -func (pm *PeerMonitor) RemovePeer(siteID int) { - pm.mutex.Lock() - defer pm.mutex.Unlock() - - pm.removePeerUnlocked(siteID) -} - -// Start begins monitoring all peers -func (pm *PeerMonitor) Start() { - pm.mutex.Lock() - defer pm.mutex.Unlock() - - if pm.running { - return // Already running - } - - pm.running = true - - // Start monitoring all peers - for siteID, client := range pm.monitors { - 
siteIDCopy := siteID // Create a copy for the closure - err := client.StartMonitor(func(status wgtester.ConnectionStatus) { - pm.handleConnectionStatusChange(siteIDCopy, status) - }) - if err != nil { - logger.Error("Failed to start monitoring peer %d: %v\n", siteID, err) - continue - } - logger.Info("Started monitoring peer %d\n", siteID) - } -} - -// handleConnectionStatusChange is called when a peer's connection status changes -func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status wgtester.ConnectionStatus) { - // Call the user-provided callback first - if pm.callback != nil { - pm.callback(siteID, status.Connected, status.RTT) - } - - // If disconnected, handle failover - if !status.Connected { - // Send relay message to the server - if pm.wsClient != nil { - pm.sendRelay(siteID) - } - } -} - -// handleFailover handles failover to the relay server when a peer is disconnected -func (pm *PeerMonitor) HandleFailover(siteID int, relayEndpoint string) { - pm.mutex.Lock() - config, exists := pm.configs[siteID] - pm.mutex.Unlock() - - if !exists { - return - } - - // Check for IPv6 and format the endpoint correctly - formattedEndpoint := relayEndpoint - if strings.Contains(relayEndpoint, ":") { - formattedEndpoint = fmt.Sprintf("[%s]", relayEndpoint) - } - - // Configure WireGuard to use the relay - wgConfig := fmt.Sprintf(`private_key=%s -public_key=%s -allowed_ip=%s/32 -endpoint=%s:21820 -persistent_keepalive_interval=1`, pm.privateKey, config.PublicKey, config.ServerIP, formattedEndpoint) - - err := pm.device.IpcSet(wgConfig) - if err != nil { - logger.Error("Failed to configure WireGuard device: %v\n", err) - return - } - - logger.Info("Adjusted peer %d to point to relay!\n", siteID) -} - -// sendRelay sends a relay message to the server -func (pm *PeerMonitor) sendRelay(siteID int) error { - if !pm.handleRelaySwitch { - return nil - } - - if pm.wsClient == nil { - return fmt.Errorf("websocket client is nil") - } - - err := 
pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{ - "siteId": siteID, - }) - if err != nil { - logger.Error("Failed to send registration message: %v", err) - return err - } - logger.Info("Sent relay message") - return nil -} - -// Stop stops monitoring all peers -func (pm *PeerMonitor) Stop() { - pm.mutex.Lock() - defer pm.mutex.Unlock() - - if !pm.running { - return - } - - pm.running = false - - // Stop all monitors - for _, client := range pm.monitors { - client.StopMonitor() - } -} - -// Close stops monitoring and cleans up resources -func (pm *PeerMonitor) Close() { - pm.mutex.Lock() - defer pm.mutex.Unlock() - - // Stop and close all clients - for siteID, client := range pm.monitors { - client.StopMonitor() - client.Close() - delete(pm.monitors, siteID) - } - - pm.running = false -} - -// TestPeer tests connectivity to a specific peer -func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) { - pm.mutex.Lock() - client, exists := pm.monitors[siteID] - pm.mutex.Unlock() - - if !exists { - return false, 0, fmt.Errorf("peer with siteID %d not found", siteID) - } - - ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) - defer cancel() - - connected, rtt := client.TestConnection(ctx) - return connected, rtt, nil -} - -// TestAllPeers tests connectivity to all peers -func (pm *PeerMonitor) TestAllPeers() map[int]struct { - Connected bool - RTT time.Duration -} { - pm.mutex.Lock() - peers := make(map[int]*wgtester.Client, len(pm.monitors)) - for siteID, client := range pm.monitors { - peers[siteID] = client - } - pm.mutex.Unlock() - - results := make(map[int]struct { - Connected bool - RTT time.Duration - }) - for siteID, client := range peers { - ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) - connected, rtt := client.TestConnection(ctx) - cancel() - - results[siteID] = struct { - Connected bool - RTT time.Duration - }{ - Connected: 
connected, - RTT: rtt, - } - } - - return results -} diff --git a/peers/manager.go b/peers/manager.go new file mode 100644 index 0000000..59af2ce --- /dev/null +++ b/peers/manager.go @@ -0,0 +1,884 @@ +package peers + +import ( + "fmt" + "net" + "strconv" + "strings" + "sync" + + "github.com/fosrl/newt/bind" + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/network" + "github.com/fosrl/newt/util" + "github.com/fosrl/olm/api" + olmDevice "github.com/fosrl/olm/device" + "github.com/fosrl/olm/dns" + "github.com/fosrl/olm/peers/monitor" + "github.com/fosrl/olm/websocket" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +// PeerManagerConfig contains the configuration for creating a PeerManager +type PeerManagerConfig struct { + Device *device.Device + DNSProxy *dns.DNSProxy + InterfaceName string + PrivateKey wgtypes.Key + // For peer monitoring + MiddleDev *olmDevice.MiddleDevice + LocalIP string + SharedBind *bind.SharedBind + // WSClient is optional - if nil, relay messages won't be sent + WSClient *websocket.Client + APIServer *api.API +} + +type PeerManager struct { + mu sync.RWMutex + device *device.Device + peers map[int]SiteConfig + peerMonitor *monitor.PeerMonitor + dnsProxy *dns.DNSProxy + interfaceName string + privateKey wgtypes.Key + // allowedIPOwners tracks which peer currently "owns" each allowed IP in WireGuard + // key is the CIDR string, value is the siteId that has it configured in WG + allowedIPOwners map[string]int + // allowedIPClaims tracks all peers that claim each allowed IP + // key is the CIDR string, value is a set of siteIds that want this IP + allowedIPClaims map[string]map[int]bool + APIServer *api.API +} + +// NewPeerManager creates a new PeerManager with an internal PeerMonitor +func NewPeerManager(config PeerManagerConfig) *PeerManager { + pm := &PeerManager{ + device: config.Device, + peers: make(map[int]SiteConfig), + dnsProxy: config.DNSProxy, + interfaceName: config.InterfaceName, + 
privateKey: config.PrivateKey, + allowedIPOwners: make(map[string]int), + allowedIPClaims: make(map[string]map[int]bool), + APIServer: config.APIServer, + } + + // Create the peer monitor + pm.peerMonitor = monitor.NewPeerMonitor( + config.WSClient, + config.MiddleDev, + config.LocalIP, + config.SharedBind, + config.APIServer, + ) + + return pm +} + +func (pm *PeerManager) GetPeer(siteId int) (SiteConfig, bool) { + pm.mu.RLock() + defer pm.mu.RUnlock() + peer, ok := pm.peers[siteId] + return peer, ok +} + +func (pm *PeerManager) GetAllPeers() []SiteConfig { + pm.mu.RLock() + defer pm.mu.RUnlock() + peers := make([]SiteConfig, 0, len(pm.peers)) + for _, peer := range pm.peers { + peers = append(peers, peer) + } + return peers +} + +func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error { + pm.mu.Lock() + defer pm.mu.Unlock() + + // build the allowed IPs list from the remote subnets and aliases and add them to the peer + allowedIPs := make([]string, 0, len(siteConfig.RemoteSubnets)+len(siteConfig.Aliases)) + allowedIPs = append(allowedIPs, siteConfig.RemoteSubnets...) 
+ for _, alias := range siteConfig.Aliases { + allowedIPs = append(allowedIPs, alias.AliasAddress+"/32") + } + siteConfig.AllowedIps = allowedIPs + + // Register claims for all allowed IPs and determine which ones this peer will own + ownedIPs := make([]string, 0, len(allowedIPs)) + for _, ip := range allowedIPs { + pm.claimAllowedIP(siteConfig.SiteId, ip) + // Check if this peer became the owner + if pm.allowedIPOwners[ip] == siteConfig.SiteId { + ownedIPs = append(ownedIPs, ip) + } + } + + // Create a config with only the owned IPs for WireGuard + wgConfig := siteConfig + wgConfig.AllowedIps = ownedIPs + + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil { + return err + } + + if err := network.AddRouteForServerIP(siteConfig.ServerIP, pm.interfaceName); err != nil { + logger.Error("Failed to add route for server IP: %v", err) + } + if err := network.AddRoutes(siteConfig.RemoteSubnets, pm.interfaceName); err != nil { + logger.Error("Failed to add routes for remote subnets: %v", err) + } + for _, alias := range siteConfig.Aliases { + address := net.ParseIP(alias.AliasAddress) + if address == nil { + continue + } + pm.dnsProxy.AddDNSRecord(alias.Alias, address) + } + + monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0] + monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port + + err := pm.peerMonitor.AddPeer(siteConfig.SiteId, monitorPeer, siteConfig.Endpoint) // always use the real site endpoint for hole punch monitoring + if err != nil { + logger.Warn("Failed to setup monitoring for site %d: %v", siteConfig.SiteId, err) + } else { + logger.Info("Started monitoring for site %d at %s", siteConfig.SiteId, monitorPeer) + } + + pm.peers[siteConfig.SiteId] = siteConfig + + pm.APIServer.AddPeerStatus(siteConfig.SiteId, siteConfig.Name, false, 0, siteConfig.Endpoint, false) + + // Perform rapid initial holepunch test (outside of 
lock to avoid blocking) + // This quickly determines if holepunch is viable and triggers relay if not + go pm.performRapidInitialTest(siteConfig.SiteId, siteConfig.Endpoint) + + return nil +} + +func (pm *PeerManager) RemovePeer(siteId int) error { + pm.mu.Lock() + defer pm.mu.Unlock() + + peer, exists := pm.peers[siteId] + if !exists { + return fmt.Errorf("peer with site ID %d not found", siteId) + } + + if err := RemovePeer(pm.device, siteId, peer.PublicKey); err != nil { + return err + } + + if err := network.RemoveRouteForServerIP(peer.ServerIP, pm.interfaceName); err != nil { + logger.Error("Failed to remove route for server IP: %v", err) + } + + // Only remove routes for subnets that aren't used by other peers + for _, subnet := range peer.RemoteSubnets { + subnetStillInUse := false + for otherSiteId, otherPeer := range pm.peers { + if otherSiteId == siteId { + continue // Skip the peer being removed + } + for _, otherSubnet := range otherPeer.RemoteSubnets { + if otherSubnet == subnet { + subnetStillInUse = true + break + } + } + if subnetStillInUse { + break + } + } + if !subnetStillInUse { + if err := network.RemoveRoutes([]string{subnet}); err != nil { + logger.Error("Failed to remove route for remote subnet %s: %v", subnet, err) + } + } + } + + // For aliases + for _, alias := range peer.Aliases { + address := net.ParseIP(alias.AliasAddress) + if address == nil { + continue + } + pm.dnsProxy.RemoveDNSRecord(alias.Alias, address) + } + + // Release all IP claims and promote other peers as needed + // Collect promotions first to avoid modifying while iterating + type promotion struct { + newOwner int + cidr string + } + var promotions []promotion + + for _, ip := range peer.AllowedIps { + newOwner, promoted := pm.releaseAllowedIP(siteId, ip) + if promoted && newOwner >= 0 { + promotions = append(promotions, promotion{newOwner: newOwner, cidr: ip}) + } + } + + // Apply promotions - update WireGuard config for newly promoted peers + // Group by peer to avoid 
multiple config updates + promotedPeers := make(map[int]bool) + for _, p := range promotions { + promotedPeers[p.newOwner] = true + logger.Info("Promoted peer %d to owner of IP %s", p.newOwner, p.cidr) + } + + for promotedPeerId := range promotedPeers { + if promotedPeer, exists := pm.peers[promotedPeerId]; exists { + // Build the list of IPs this peer now owns + ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId) + wgConfig := promotedPeer + wgConfig.AllowedIps = ownedIPs + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil { + logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err) + } + } + } + + // Stop monitoring this peer + pm.peerMonitor.RemovePeer(siteId) + logger.Info("Stopped monitoring for site %d", siteId) + + pm.APIServer.RemovePeerStatus(siteId) + + delete(pm.peers, siteId) + return nil +} + +func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error { + pm.mu.Lock() + defer pm.mu.Unlock() + + oldPeer, exists := pm.peers[siteConfig.SiteId] + if !exists { + return fmt.Errorf("peer with site ID %d not found", siteConfig.SiteId) + } + + // If public key changed, remove old peer first + if siteConfig.PublicKey != oldPeer.PublicKey { + if err := RemovePeer(pm.device, siteConfig.SiteId, oldPeer.PublicKey); err != nil { + logger.Error("Failed to remove old peer: %v", err) + } + } + + // Build the new allowed IPs list + newAllowedIPs := make([]string, 0, len(siteConfig.RemoteSubnets)+len(siteConfig.Aliases)) + newAllowedIPs = append(newAllowedIPs, siteConfig.RemoteSubnets...) 
+ for _, alias := range siteConfig.Aliases { + newAllowedIPs = append(newAllowedIPs, alias.AliasAddress+"/32") + } + siteConfig.AllowedIps = newAllowedIPs + + // Handle allowed IP claim changes + oldAllowedIPs := make(map[string]bool) + for _, ip := range oldPeer.AllowedIps { + oldAllowedIPs[ip] = true + } + newAllowedIPsSet := make(map[string]bool) + for _, ip := range newAllowedIPs { + newAllowedIPsSet[ip] = true + } + + // Track peers that need WireGuard config updates due to promotions + peersToUpdate := make(map[int]bool) + + // Release claims for removed IPs and handle promotions + for ip := range oldAllowedIPs { + if !newAllowedIPsSet[ip] { + newOwner, promoted := pm.releaseAllowedIP(siteConfig.SiteId, ip) + if promoted && newOwner >= 0 { + peersToUpdate[newOwner] = true + logger.Info("Promoted peer %d to owner of IP %s", newOwner, ip) + } + } + } + + // Add claims for new IPs + for ip := range newAllowedIPsSet { + if !oldAllowedIPs[ip] { + pm.claimAllowedIP(siteConfig.SiteId, ip) + } + } + + // Build the list of IPs this peer owns for WireGuard config + ownedIPs := pm.getOwnedAllowedIPs(siteConfig.SiteId) + wgConfig := siteConfig + wgConfig.AllowedIps = ownedIPs + + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil { + return err + } + + // Update WireGuard config for any promoted peers + for promotedPeerId := range peersToUpdate { + if promotedPeer, exists := pm.peers[promotedPeerId]; exists { + promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId) + promotedWgConfig := promotedPeer + promotedWgConfig.AllowedIps = promotedOwnedIPs + if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil { + logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err) + } + } + } + + // Handle remote subnet route changes + // Calculate added and removed subnets + oldSubnets := make(map[string]bool) + for _, s := 
range oldPeer.RemoteSubnets { + oldSubnets[s] = true + } + newSubnets := make(map[string]bool) + for _, s := range siteConfig.RemoteSubnets { + newSubnets[s] = true + } + + var addedSubnets []string + var removedSubnets []string + + for s := range newSubnets { + if !oldSubnets[s] { + addedSubnets = append(addedSubnets, s) + } + } + for s := range oldSubnets { + if !newSubnets[s] { + removedSubnets = append(removedSubnets, s) + } + } + + // Remove routes for removed subnets (only if no other peer needs them) + for _, subnet := range removedSubnets { + subnetStillInUse := false + for otherSiteId, otherPeer := range pm.peers { + if otherSiteId == siteConfig.SiteId { + continue // Skip the current peer (already updated) + } + for _, otherSubnet := range otherPeer.RemoteSubnets { + if otherSubnet == subnet { + subnetStillInUse = true + break + } + } + if subnetStillInUse { + break + } + } + if !subnetStillInUse { + if err := network.RemoveRoutes([]string{subnet}); err != nil { + logger.Error("Failed to remove route for subnet %s: %v", subnet, err) + } + } + } + + // Add routes for added subnets + if len(addedSubnets) > 0 { + if err := network.AddRoutes(addedSubnets, pm.interfaceName); err != nil { + logger.Error("Failed to add routes: %v", err) + } + } + + // Update aliases + // Remove old aliases + for _, alias := range oldPeer.Aliases { + address := net.ParseIP(alias.AliasAddress) + if address == nil { + continue + } + pm.dnsProxy.RemoveDNSRecord(alias.Alias, address) + } + // Add new aliases + for _, alias := range siteConfig.Aliases { + address := net.ParseIP(alias.AliasAddress) + if address == nil { + continue + } + pm.dnsProxy.AddDNSRecord(alias.Alias, address) + } + + pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint) + + monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0] + monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port + 
pm.peerMonitor.UpdatePeerEndpoint(siteConfig.SiteId, monitorPeer) // +1 for monitor port + + pm.peers[siteConfig.SiteId] = siteConfig + return nil +} + +// claimAllowedIP registers a peer's claim to an allowed IP. +// If no other peer owns it in WireGuard, this peer becomes the owner. +// Must be called with lock held. +func (pm *PeerManager) claimAllowedIP(siteId int, cidr string) { + // Add to claims + if pm.allowedIPClaims[cidr] == nil { + pm.allowedIPClaims[cidr] = make(map[int]bool) + } + pm.allowedIPClaims[cidr][siteId] = true + + // If no owner yet, this peer becomes the owner + if _, hasOwner := pm.allowedIPOwners[cidr]; !hasOwner { + pm.allowedIPOwners[cidr] = siteId + } +} + +// releaseAllowedIP removes a peer's claim to an allowed IP. +// If this peer was the owner, it promotes another claimant to owner. +// Returns the new owner's siteId (or -1 if no new owner) and whether promotion occurred. +// Must be called with lock held. +func (pm *PeerManager) releaseAllowedIP(siteId int, cidr string) (newOwner int, promoted bool) { + // Remove from claims + if claims, exists := pm.allowedIPClaims[cidr]; exists { + delete(claims, siteId) + if len(claims) == 0 { + delete(pm.allowedIPClaims, cidr) + } + } + + // Check if this peer was the owner + owner, isOwned := pm.allowedIPOwners[cidr] + if !isOwned || owner != siteId { + return -1, false // Not the owner, nothing to promote + } + + // This peer was the owner, need to find a new owner + delete(pm.allowedIPOwners, cidr) + + // Find another claimant to promote + if claims, exists := pm.allowedIPClaims[cidr]; exists && len(claims) > 0 { + for claimantId := range claims { + pm.allowedIPOwners[cidr] = claimantId + return claimantId, true + } + } + + return -1, false +} + +// getOwnedAllowedIPs returns the list of allowed IPs that a peer currently owns in WireGuard. +// Must be called with lock held. 
+func (pm *PeerManager) getOwnedAllowedIPs(siteId int) []string { + var owned []string + for cidr, owner := range pm.allowedIPOwners { + if owner == siteId { + owned = append(owned, cidr) + } + } + return owned +} + +// addAllowedIp adds an IP (subnet) to the allowed IPs list of a peer +// and updates WireGuard configuration if this peer owns the IP. +// Must be called with lock held. +func (pm *PeerManager) addAllowedIp(siteId int, ip string) error { + peer, exists := pm.peers[siteId] + if !exists { + return fmt.Errorf("peer with site ID %d not found", siteId) + } + + // Check if IP already exists in AllowedIps + for _, allowedIp := range peer.AllowedIps { + if allowedIp == ip { + return nil // Already exists + } + } + + // Register our claim to this IP + pm.claimAllowedIP(siteId, ip) + + peer.AllowedIps = append(peer.AllowedIps, ip) + pm.peers[siteId] = peer + + // Only update WireGuard if we own this IP + if pm.allowedIPOwners[ip] == siteId { + if err := AddAllowedIP(pm.device, peer.PublicKey, ip); err != nil { + return err + } + } + + return nil +} + +// removeAllowedIp removes an IP (subnet) from the allowed IPs list of a peer +// and updates WireGuard configuration. If this peer owned the IP, it promotes +// another peer that also claims this IP. Must be called with lock held. 
+func (pm *PeerManager) removeAllowedIp(siteId int, cidr string) error { + peer, exists := pm.peers[siteId] + if !exists { + return fmt.Errorf("peer with site ID %d not found", siteId) + } + + found := false + + // Remove from AllowedIps + newAllowedIps := make([]string, 0, len(peer.AllowedIps)) + for _, allowedIp := range peer.AllowedIps { + if allowedIp == cidr { + found = true + continue + } + newAllowedIps = append(newAllowedIps, allowedIp) + } + + if !found { + return nil // Not found + } + + peer.AllowedIps = newAllowedIps + pm.peers[siteId] = peer + + // Release our claim and check if we need to promote another peer + newOwner, promoted := pm.releaseAllowedIP(siteId, cidr) + + // Build the list of IPs this peer currently owns for the replace operation + ownedIPs := pm.getOwnedAllowedIPs(siteId) + // Also include the server IP which is always owned + serverIP := strings.Split(peer.ServerIP, "/")[0] + "/32" + hasServerIP := false + for _, ip := range ownedIPs { + if ip == serverIP { + hasServerIP = true + break + } + } + if !hasServerIP { + ownedIPs = append([]string{serverIP}, ownedIPs...) 
+ } + + // Update WireGuard for this peer using replace_allowed_ips + if err := RemoveAllowedIP(pm.device, peer.PublicKey, ownedIPs); err != nil { + return err + } + + // If another peer was promoted to owner, add the IP to their WireGuard config + if promoted && newOwner >= 0 { + if newOwnerPeer, exists := pm.peers[newOwner]; exists { + if err := AddAllowedIP(pm.device, newOwnerPeer.PublicKey, cidr); err != nil { + logger.Error("Failed to promote peer %d for IP %s: %v", newOwner, cidr, err) + } else { + logger.Info("Promoted peer %d to owner of IP %s", newOwner, cidr) + } + } + } + + return nil +} + +// AddRemoteSubnet adds an IP (subnet) to the allowed IPs list of a peer +func (pm *PeerManager) AddRemoteSubnet(siteId int, cidr string) error { + pm.mu.Lock() + defer pm.mu.Unlock() + + peer, exists := pm.peers[siteId] + if !exists { + return fmt.Errorf("peer with site ID %d not found", siteId) + } + + // Check if IP already exists in RemoteSubnets + for _, subnet := range peer.RemoteSubnets { + if subnet == cidr { + return nil // Already exists + } + } + + peer.RemoteSubnets = append(peer.RemoteSubnets, cidr) + pm.peers[siteId] = peer // Save before calling addAllowedIp which reads from pm.peers + + // Add to allowed IPs + if err := pm.addAllowedIp(siteId, cidr); err != nil { + return err + } + + // Add route + if err := network.AddRoutes([]string{cidr}, pm.interfaceName); err != nil { + return err + } + + return nil +} + +// RemoveRemoteSubnet removes an IP (subnet) from the allowed IPs list of a peer +func (pm *PeerManager) RemoveRemoteSubnet(siteId int, ip string) error { + pm.mu.Lock() + defer pm.mu.Unlock() + + peer, exists := pm.peers[siteId] + if !exists { + return fmt.Errorf("peer with site ID %d not found", siteId) + } + + found := false + + // Remove from RemoteSubnets + newSubnets := make([]string, 0, len(peer.RemoteSubnets)) + for _, subnet := range peer.RemoteSubnets { + if subnet == ip { + found = true + continue + } + newSubnets = append(newSubnets, 
subnet) + } + + if !found { + return nil // Not found + } + + peer.RemoteSubnets = newSubnets + pm.peers[siteId] = peer // Save before calling removeAllowedIp which reads from pm.peers + + // Remove from allowed IPs (this also handles promotion of other peers) + if err := pm.removeAllowedIp(siteId, ip); err != nil { + return err + } + + // Check if any other peer still has this subnet before removing the route + subnetStillInUse := false + for otherSiteId, otherPeer := range pm.peers { + if otherSiteId == siteId { + continue // Skip the current peer (already updated above) + } + for _, subnet := range otherPeer.RemoteSubnets { + if subnet == ip { + subnetStillInUse = true + break + } + } + if subnetStillInUse { + break + } + } + + // Only remove route if no other peer needs it + if !subnetStillInUse { + if err := network.RemoveRoutes([]string{ip}); err != nil { + return err + } + } + + return nil +} + +// AddAlias adds an alias to a peer +func (pm *PeerManager) AddAlias(siteId int, alias Alias) error { + pm.mu.Lock() + defer pm.mu.Unlock() + + peer, exists := pm.peers[siteId] + if !exists { + return fmt.Errorf("peer with site ID %d not found", siteId) + } + + peer.Aliases = append(peer.Aliases, alias) + pm.peers[siteId] = peer + + address := net.ParseIP(alias.AliasAddress) + if address != nil { + pm.dnsProxy.AddDNSRecord(alias.Alias, address) + } + + // Add an allowed IP for the alias + if err := pm.addAllowedIp(siteId, alias.AliasAddress+"/32"); err != nil { + return err + } + + return nil +} + +// RemoveAlias removes an alias from a peer +func (pm *PeerManager) RemoveAlias(siteId int, aliasName string) error { + pm.mu.Lock() + defer pm.mu.Unlock() + + peer, exists := pm.peers[siteId] + if !exists { + return fmt.Errorf("peer with site ID %d not found", siteId) + } + + var aliasToRemove *Alias + newAliases := make([]Alias, 0, len(peer.Aliases)) + for _, a := range peer.Aliases { + if a.Alias == aliasName { + aliasToRemove = &a + continue + } + newAliases = 
append(newAliases, a) + } + + if aliasToRemove != nil { + address := net.ParseIP(aliasToRemove.AliasAddress) + if address != nil { + pm.dnsProxy.RemoveDNSRecord(aliasName, address) + } + } + + peer.Aliases = newAliases + pm.peers[siteId] = peer + + // Check if any other alias is still using this IP address before removing from allowed IPs + ipStillInUse := false + aliasIP := aliasToRemove.AliasAddress + "/32" + for _, a := range newAliases { + if a.AliasAddress+"/32" == aliasIP { + ipStillInUse = true + break + } + } + + // Only remove the allowed IP if no other alias is using it + if !ipStillInUse { + if err := pm.removeAllowedIp(siteId, aliasIP); err != nil { + return err + } + } + + return nil +} + +// RelayPeer handles failover to the relay server when a peer is disconnected +func (pm *PeerManager) RelayPeer(siteId int, relayEndpoint string) { + pm.mu.Lock() + peer, exists := pm.peers[siteId] + if exists { + // Store the relay endpoint + peer.RelayEndpoint = relayEndpoint + pm.peers[siteId] = peer + } + pm.mu.Unlock() + + if !exists { + logger.Error("Cannot handle failover: peer with site ID %d not found", siteId) + return + } + + // Check for IPv6 and format the endpoint correctly + formattedEndpoint := relayEndpoint + if strings.Contains(relayEndpoint, ":") { + formattedEndpoint = fmt.Sprintf("[%s]", relayEndpoint) + } + + // Update only the endpoint for this peer (update_only preserves other settings) + wgConfig := fmt.Sprintf(`public_key=%s +update_only=true +endpoint=%s:21820`, util.FixKey(peer.PublicKey), formattedEndpoint) + + err := pm.device.IpcSet(wgConfig) + if err != nil { + logger.Error("Failed to configure WireGuard device: %v\n", err) + return + } + + // Mark the peer as relayed in the monitor + if pm.peerMonitor != nil { + pm.peerMonitor.MarkPeerRelayed(siteId, true) + } + + logger.Info("Adjusted peer %d to point to relay!\n", siteId) +} + +// performRapidInitialTest performs a rapid holepunch test for a newly added peer. 
+// If the test fails, it immediately requests relay to minimize connection delay. +// This runs in a goroutine to avoid blocking AddPeer. +func (pm *PeerManager) performRapidInitialTest(siteId int, endpoint string) { + if pm.peerMonitor == nil { + return + } + + // Perform rapid test - this takes ~1-2 seconds max + holepunchViable := pm.peerMonitor.RapidTestPeer(siteId, endpoint) + + if !holepunchViable { + // Holepunch failed rapid test, request relay immediately + logger.Info("Rapid test failed for site %d, requesting relay", siteId) + if err := pm.peerMonitor.RequestRelay(siteId); err != nil { + logger.Error("Failed to request relay for site %d: %v", siteId, err) + } + } else { + logger.Info("Rapid test passed for site %d, using direct connection", siteId) + } +} + +// Start starts the peer monitor +func (pm *PeerManager) Start() { + if pm.peerMonitor != nil { + pm.peerMonitor.Start() + } +} + +// Stop stops the peer monitor +func (pm *PeerManager) Stop() { + if pm.peerMonitor != nil { + pm.peerMonitor.Stop() + } +} + +// Close stops the peer monitor and cleans up resources +func (pm *PeerManager) Close() { + if pm.peerMonitor != nil { + pm.peerMonitor.Close() + pm.peerMonitor = nil + } +} + +// MarkPeerRelayed marks a peer as currently using relay +func (pm *PeerManager) MarkPeerRelayed(siteID int, relayed bool) { + pm.mu.Lock() + if peer, exists := pm.peers[siteID]; exists { + if relayed { + // We're being relayed, store the current endpoint as the original + // (RelayEndpoint is set by HandleFailover) + } else { + // Clear relay endpoint when switching back to direct + peer.RelayEndpoint = "" + pm.peers[siteID] = peer + } + } + pm.mu.Unlock() + + if pm.peerMonitor != nil { + pm.peerMonitor.MarkPeerRelayed(siteID, relayed) + } +} + +// UnRelayPeer switches a peer from relay back to direct connection +func (pm *PeerManager) UnRelayPeer(siteId int, endpoint string) error { + pm.mu.Lock() + peer, exists := pm.peers[siteId] + if exists { + // Store the relay 
endpoint + peer.Endpoint = endpoint + pm.peers[siteId] = peer + } + pm.mu.Unlock() + + if !exists { + logger.Error("Cannot handle failover: peer with site ID %d not found", siteId) + return nil + } + + // Update WireGuard to use the direct endpoint + wgConfig := fmt.Sprintf(`public_key=%s +update_only=true +endpoint=%s`, util.FixKey(peer.PublicKey), endpoint) + + err := pm.device.IpcSet(wgConfig) + if err != nil { + logger.Error("Failed to switch peer %d to direct connection: %v", siteId, err) + return err + } + + // Mark as not relayed in monitor + if pm.peerMonitor != nil { + pm.peerMonitor.MarkPeerRelayed(siteId, false) + } + + logger.Info("Switched peer %d back to direct connection at %s", siteId, endpoint) + return nil +} diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go new file mode 100644 index 0000000..ac91cb3 --- /dev/null +++ b/peers/monitor/monitor.go @@ -0,0 +1,924 @@ +package monitor + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + "time" + + "github.com/fosrl/newt/bind" + "github.com/fosrl/newt/holepunch" + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/util" + "github.com/fosrl/olm/api" + middleDevice "github.com/fosrl/olm/device" + "github.com/fosrl/olm/websocket" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" +) + +// PeerMonitor handles monitoring the connection status to multiple WireGuard peers +type PeerMonitor struct { + monitors map[int]*Client + mutex sync.Mutex + running bool + interval time.Duration + timeout time.Duration + maxAttempts int + wsClient *websocket.Client + + // Netstack fields + middleDev *middleDevice.MiddleDevice + localIP string + stack *stack.Stack + ep 
*channel.Endpoint + activePorts map[uint16]bool + portsLock sync.Mutex + nsCtx context.Context + nsCancel context.CancelFunc + nsWg sync.WaitGroup + + // Holepunch testing fields + sharedBind *bind.SharedBind + holepunchTester *holepunch.HolepunchTester + holepunchInterval time.Duration + holepunchTimeout time.Duration + holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing + holepunchStatus map[int]bool // siteID -> connected status + holepunchStopChan chan struct{} + + // Relay tracking fields + relayedPeers map[int]bool // siteID -> whether the peer is currently relayed + holepunchMaxAttempts int // max consecutive failures before triggering relay + holepunchFailures map[int]int // siteID -> consecutive failure count + + // Rapid initial test fields + rapidTestInterval time.Duration // interval between rapid test attempts + rapidTestTimeout time.Duration // timeout for each rapid test attempt + rapidTestMaxAttempts int // max attempts during rapid test phase + + // API server for status updates + apiServer *api.API + + // WG connection status tracking + wgConnectionStatus map[int]bool // siteID -> WG connected status +} + +// NewPeerMonitor creates a new peer monitor with the given callback +func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API) *PeerMonitor { + ctx, cancel := context.WithCancel(context.Background()) + pm := &PeerMonitor{ + monitors: make(map[int]*Client), + interval: 2 * time.Second, // Default check interval (faster) + timeout: 3 * time.Second, + maxAttempts: 3, + wsClient: wsClient, + middleDev: middleDev, + localIP: localIP, + activePorts: make(map[uint16]bool), + nsCtx: ctx, + nsCancel: cancel, + sharedBind: sharedBind, + holepunchInterval: 2 * time.Second, // Check holepunch every 2 seconds + holepunchTimeout: 2 * time.Second, // Faster timeout + holepunchEndpoints: make(map[int]string), + holepunchStatus: 
make(map[int]bool), + relayedPeers: make(map[int]bool), + holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures + holepunchFailures: make(map[int]int), + // Rapid initial test settings: complete within ~1.5 seconds + rapidTestInterval: 200 * time.Millisecond, // 200ms between attempts + rapidTestTimeout: 400 * time.Millisecond, // 400ms timeout per attempt + rapidTestMaxAttempts: 5, // 5 attempts = ~1-1.5 seconds total + apiServer: apiServer, + wgConnectionStatus: make(map[int]bool), + } + + if err := pm.initNetstack(); err != nil { + logger.Error("Failed to initialize netstack for peer monitor: %v", err) + } + + // Initialize holepunch tester if sharedBind is available + if sharedBind != nil { + pm.holepunchTester = holepunch.NewHolepunchTester(sharedBind) + } + + return pm +} + +// SetInterval changes how frequently peers are checked +func (pm *PeerMonitor) SetInterval(interval time.Duration) { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + pm.interval = interval + + // Update interval for all existing monitors + for _, client := range pm.monitors { + client.SetPacketInterval(interval) + } +} + +// SetTimeout changes the timeout for waiting for responses +func (pm *PeerMonitor) SetTimeout(timeout time.Duration) { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + pm.timeout = timeout + + // Update timeout for all existing monitors + for _, client := range pm.monitors { + client.SetTimeout(timeout) + } +} + +// SetMaxAttempts changes the maximum number of attempts for TestConnection +func (pm *PeerMonitor) SetMaxAttempts(attempts int) { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + pm.maxAttempts = attempts + + // Update max attempts for all existing monitors + for _, client := range pm.monitors { + client.SetMaxAttempts(attempts) + } +} + +// AddPeer adds a new peer to monitor +func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint string) error { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + if _, exists := 
pm.monitors[siteID]; exists { + return nil // Already monitoring + } + + // Use our custom dialer that uses netstack + client, err := NewClient(endpoint, pm.dial) + if err != nil { + return err + } + + client.SetPacketInterval(pm.interval) + client.SetTimeout(pm.timeout) + client.SetMaxAttempts(pm.maxAttempts) + + pm.monitors[siteID] = client + + pm.holepunchEndpoints[siteID] = holepunchEndpoint + pm.holepunchStatus[siteID] = false // Initially unknown/disconnected + + if pm.running { + if err := client.StartMonitor(func(status ConnectionStatus) { + pm.handleConnectionStatusChange(siteID, status) + }); err != nil { + return err + } + } + + return nil +} + +// update holepunch endpoint for a peer +func (pm *PeerMonitor) UpdateHolepunchEndpoint(siteID int, endpoint string) { + go func() { + time.Sleep(3 * time.Second) + pm.mutex.Lock() + defer pm.mutex.Unlock() + pm.holepunchEndpoints[siteID] = endpoint + }() +} + +// RapidTestPeer performs a rapid connectivity test for a newly added peer. +// This is designed to quickly determine if holepunch is viable within ~1-2 seconds. +// Returns true if the connection is viable (holepunch works), false if it should relay. 
+func (pm *PeerMonitor) RapidTestPeer(siteID int, endpoint string) bool { + if pm.holepunchTester == nil { + logger.Warn("Cannot perform rapid test: holepunch tester not initialized") + return false + } + + pm.mutex.Lock() + interval := pm.rapidTestInterval + timeout := pm.rapidTestTimeout + maxAttempts := pm.rapidTestMaxAttempts + pm.mutex.Unlock() + + logger.Info("Starting rapid holepunch test for site %d at %s (max %d attempts, %v timeout each)", + siteID, endpoint, maxAttempts, timeout) + + for attempt := 1; attempt <= maxAttempts; attempt++ { + result := pm.holepunchTester.TestEndpoint(endpoint, timeout) + + if result.Success { + logger.Info("Rapid test: site %d holepunch SUCCEEDED on attempt %d (RTT: %v)", + siteID, attempt, result.RTT) + + // Update status + pm.mutex.Lock() + pm.holepunchStatus[siteID] = true + pm.holepunchFailures[siteID] = 0 + pm.mutex.Unlock() + + return true + } + + if attempt < maxAttempts { + time.Sleep(interval) + } + } + + logger.Warn("Rapid test: site %d holepunch FAILED after %d attempts, will relay", + siteID, maxAttempts) + + // Update status to reflect failure + pm.mutex.Lock() + pm.holepunchStatus[siteID] = false + pm.holepunchFailures[siteID] = maxAttempts + pm.mutex.Unlock() + + return false +} + +// UpdatePeerEndpoint updates the monitor endpoint for a peer +func (pm *PeerMonitor) UpdatePeerEndpoint(siteID int, monitorPeer string) { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + client, exists := pm.monitors[siteID] + if !exists { + logger.Warn("Cannot update endpoint: peer %d not found in monitor", siteID) + return + } + + // Update the client's server address + client.UpdateServerAddr(monitorPeer) + + logger.Info("Updated monitor endpoint for site %d to %s", siteID, monitorPeer) +} + +// removePeerUnlocked stops monitoring a peer and removes it from the monitor +// This function assumes the mutex is already held by the caller +func (pm *PeerMonitor) removePeerUnlocked(siteID int) { + client, exists := pm.monitors[siteID] 
+ if !exists { + return + } + + client.StopMonitor() + client.Close() + delete(pm.monitors, siteID) +} + +// RemovePeer stops monitoring a peer and removes it from the monitor +func (pm *PeerMonitor) RemovePeer(siteID int) { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + // remove the holepunch endpoint info + delete(pm.holepunchEndpoints, siteID) + delete(pm.holepunchStatus, siteID) + delete(pm.relayedPeers, siteID) + delete(pm.holepunchFailures, siteID) + + pm.removePeerUnlocked(siteID) +} + +// Start begins monitoring all peers +func (pm *PeerMonitor) Start() { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + if pm.running { + return // Already running + } + + pm.running = true + + // Start monitoring all peers + for siteID, client := range pm.monitors { + siteIDCopy := siteID // Create a copy for the closure + err := client.StartMonitor(func(status ConnectionStatus) { + pm.handleConnectionStatusChange(siteIDCopy, status) + }) + if err != nil { + logger.Error("Failed to start monitoring peer %d: %v\n", siteID, err) + continue + } + logger.Info("Started monitoring peer %d\n", siteID) + } + + pm.startHolepunchMonitor() +} + +// handleConnectionStatusChange is called when a peer's connection status changes +func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status ConnectionStatus) { + pm.mutex.Lock() + previousStatus, exists := pm.wgConnectionStatus[siteID] + pm.wgConnectionStatus[siteID] = status.Connected + isRelayed := pm.relayedPeers[siteID] + endpoint := pm.holepunchEndpoints[siteID] + pm.mutex.Unlock() + + // Log status changes + if !exists || previousStatus != status.Connected { + if status.Connected { + logger.Info("WireGuard connection to site %d is CONNECTED (RTT: %v)", siteID, status.RTT) + } else { + logger.Warn("WireGuard connection to site %d is DISCONNECTED", siteID) + } + } + + // Update API with connection status + if pm.apiServer != nil { + pm.apiServer.UpdatePeerStatus(siteID, status.Connected, status.RTT, endpoint, isRelayed) + } 
+} + +// sendRelay sends a relay message to the server +func (pm *PeerMonitor) sendRelay(siteID int) error { + if pm.wsClient == nil { + return fmt.Errorf("websocket client is nil") + } + + err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{ + "siteId": siteID, + }) + if err != nil { + logger.Error("Failed to send registration message: %v", err) + return err + } + logger.Info("Sent relay message") + return nil +} + +// RequestRelay is a public method to request relay for a peer. +// This is used when rapid initial testing determines holepunch is not viable. +func (pm *PeerMonitor) RequestRelay(siteID int) error { + return pm.sendRelay(siteID) +} + +// sendUnRelay sends an unrelay message to the server +func (pm *PeerMonitor) sendUnRelay(siteID int) error { + if pm.wsClient == nil { + return fmt.Errorf("websocket client is nil") + } + + err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{ + "siteId": siteID, + }) + if err != nil { + logger.Error("Failed to send registration message: %v", err) + return err + } + logger.Info("Sent unrelay message") + return nil +} + +// Stop stops monitoring all peers +func (pm *PeerMonitor) Stop() { + // Stop holepunch monitor first (outside of mutex to avoid deadlock) + pm.stopHolepunchMonitor() + + pm.mutex.Lock() + defer pm.mutex.Unlock() + + if !pm.running { + return + } + + pm.running = false + + // Stop all monitors + for _, client := range pm.monitors { + client.StopMonitor() + } +} + +// MarkPeerRelayed marks a peer as currently using relay +func (pm *PeerMonitor) MarkPeerRelayed(siteID int, relayed bool) { + pm.mutex.Lock() + defer pm.mutex.Unlock() + pm.relayedPeers[siteID] = relayed + if relayed { + // Reset failure count when marked as relayed + pm.holepunchFailures[siteID] = 0 + } +} + +// IsPeerRelayed returns whether a peer is currently using relay +func (pm *PeerMonitor) IsPeerRelayed(siteID int) bool { + pm.mutex.Lock() + defer pm.mutex.Unlock() + return pm.relayedPeers[siteID] 
+} + +// startHolepunchMonitor starts the holepunch connection monitoring +// Note: This function assumes the mutex is already held by the caller (called from Start()) +func (pm *PeerMonitor) startHolepunchMonitor() error { + if pm.holepunchTester == nil { + return fmt.Errorf("holepunch tester not initialized (sharedBind not provided)") + } + + if pm.holepunchStopChan != nil { + return fmt.Errorf("holepunch monitor already running") + } + + if err := pm.holepunchTester.Start(); err != nil { + return fmt.Errorf("failed to start holepunch tester: %w", err) + } + + pm.holepunchStopChan = make(chan struct{}) + + go pm.runHolepunchMonitor() + + logger.Info("Started holepunch connection monitor") + return nil +} + +// stopHolepunchMonitor stops the holepunch connection monitoring +func (pm *PeerMonitor) stopHolepunchMonitor() { + pm.mutex.Lock() + stopChan := pm.holepunchStopChan + pm.holepunchStopChan = nil + pm.mutex.Unlock() + + if stopChan != nil { + close(stopChan) + } + + if pm.holepunchTester != nil { + pm.holepunchTester.Stop() + } + + logger.Info("Stopped holepunch connection monitor") +} + +// runHolepunchMonitor runs the holepunch monitoring loop +func (pm *PeerMonitor) runHolepunchMonitor() { + ticker := time.NewTicker(pm.holepunchInterval) + defer ticker.Stop() + + // Do initial check immediately + pm.checkHolepunchEndpoints() + + for { + select { + case <-pm.holepunchStopChan: + return + case <-ticker.C: + pm.checkHolepunchEndpoints() + } + } +} + +// checkHolepunchEndpoints tests all holepunch endpoints +func (pm *PeerMonitor) checkHolepunchEndpoints() { + pm.mutex.Lock() + // Check if we're still running before doing any work + if !pm.running { + pm.mutex.Unlock() + return + } + endpoints := make(map[int]string, len(pm.holepunchEndpoints)) + for siteID, endpoint := range pm.holepunchEndpoints { + endpoints[siteID] = endpoint + } + timeout := pm.holepunchTimeout + maxAttempts := pm.holepunchMaxAttempts + pm.mutex.Unlock() + + for siteID, endpoint := range 
endpoints { + logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint) + result := pm.holepunchTester.TestEndpoint(endpoint, timeout) + + pm.mutex.Lock() + // Check if peer was removed while we were testing + if _, stillExists := pm.holepunchEndpoints[siteID]; !stillExists { + pm.mutex.Unlock() + continue // Peer was removed, skip processing + } + + previousStatus, exists := pm.holepunchStatus[siteID] + pm.holepunchStatus[siteID] = result.Success + isRelayed := pm.relayedPeers[siteID] + + // Track consecutive failures for relay triggering + if result.Success { + pm.holepunchFailures[siteID] = 0 + } else { + pm.holepunchFailures[siteID]++ + } + failureCount := pm.holepunchFailures[siteID] + pm.mutex.Unlock() + + // Log status changes + if !exists || previousStatus != result.Success { + if result.Success { + logger.Info("Holepunch to site %d (%s) is CONNECTED (RTT: %v)", siteID, endpoint, result.RTT) + } else { + if result.Error != nil { + logger.Warn("Holepunch to site %d (%s) is DISCONNECTED: %v", siteID, endpoint, result.Error) + } else { + logger.Warn("Holepunch to site %d (%s) is DISCONNECTED", siteID, endpoint) + } + } + } + + // Update API with holepunch status + if pm.apiServer != nil { + // Update holepunch connection status + pm.apiServer.UpdatePeerHolepunchStatus(siteID, result.Success) + + // Get the current WG connection status for this peer + pm.mutex.Lock() + wgConnected := pm.wgConnectionStatus[siteID] + pm.mutex.Unlock() + + // Update API - use holepunch endpoint and relay status + pm.apiServer.UpdatePeerStatus(siteID, wgConnected, result.RTT, endpoint, isRelayed) + } + + // Handle relay logic based on holepunch status + // Check if we're still running before sending relay messages + pm.mutex.Lock() + stillRunning := pm.running + pm.mutex.Unlock() + + if !stillRunning { + return // Stop processing if shutdown is in progress + } + + if !result.Success && !isRelayed && failureCount >= maxAttempts { + // Holepunch failed and we're 
not relayed - trigger relay + logger.Info("Holepunch to site %d failed %d times, triggering relay", siteID, failureCount) + if pm.wsClient != nil { + pm.sendRelay(siteID) + } + } else if result.Success && isRelayed { + // Holepunch succeeded and we ARE relayed - switch back to direct + logger.Info("Holepunch to site %d succeeded while relayed, switching to direct connection", siteID) + if pm.wsClient != nil { + pm.sendUnRelay(siteID) + } + } + } +} + +// GetHolepunchStatus returns the current holepunch status for all endpoints +func (pm *PeerMonitor) GetHolepunchStatus() map[int]bool { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + status := make(map[int]bool, len(pm.holepunchStatus)) + for siteID, connected := range pm.holepunchStatus { + status[siteID] = connected + } + return status +} + +// Close stops monitoring and cleans up resources +func (pm *PeerMonitor) Close() { + // Stop holepunch monitor first (outside of mutex to avoid deadlock) + pm.stopHolepunchMonitor() + + pm.mutex.Lock() + defer pm.mutex.Unlock() + + logger.Debug("PeerMonitor: Starting cleanup") + + // Stop and close all clients first + for siteID, client := range pm.monitors { + logger.Debug("PeerMonitor: Stopping client for site %d", siteID) + client.StopMonitor() + client.Close() + delete(pm.monitors, siteID) + } + + pm.running = false + + // Clean up netstack resources + logger.Debug("PeerMonitor: Cancelling netstack context") + if pm.nsCancel != nil { + pm.nsCancel() // Signal goroutines to stop + } + + // Close the channel endpoint to unblock any pending reads + logger.Debug("PeerMonitor: Closing endpoint") + if pm.ep != nil { + pm.ep.Close() + } + + // Wait for packet sender goroutine to finish with timeout + logger.Debug("PeerMonitor: Waiting for goroutines to finish") + done := make(chan struct{}) + go func() { + pm.nsWg.Wait() + close(done) + }() + + select { + case <-done: + logger.Debug("PeerMonitor: Goroutines finished cleanly") + case <-time.After(2 * time.Second): + 
logger.Warn("PeerMonitor: Timeout waiting for goroutines to finish, proceeding anyway") + } + + // Destroy the stack last, after all goroutines are done + logger.Debug("PeerMonitor: Destroying stack") + if pm.stack != nil { + pm.stack.Destroy() + pm.stack = nil + } + + logger.Debug("PeerMonitor: Cleanup complete") +} + +// TestPeer tests connectivity to a specific peer +func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) { + pm.mutex.Lock() + client, exists := pm.monitors[siteID] + pm.mutex.Unlock() + + if !exists { + return false, 0, fmt.Errorf("peer with siteID %d not found", siteID) + } + + ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) + defer cancel() + + connected, rtt := client.TestConnection(ctx) + return connected, rtt, nil +} + +// TestAllPeers tests connectivity to all peers +func (pm *PeerMonitor) TestAllPeers() map[int]struct { + Connected bool + RTT time.Duration +} { + pm.mutex.Lock() + peers := make(map[int]*Client, len(pm.monitors)) + for siteID, client := range pm.monitors { + peers[siteID] = client + } + pm.mutex.Unlock() + + results := make(map[int]struct { + Connected bool + RTT time.Duration + }) + for siteID, client := range peers { + ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) + connected, rtt := client.TestConnection(ctx) + cancel() + + results[siteID] = struct { + Connected bool + RTT time.Duration + }{ + Connected: connected, + RTT: rtt, + } + } + + return results +} + +// initNetstack initializes the gvisor netstack +func (pm *PeerMonitor) initNetstack() error { + if pm.localIP == "" { + return fmt.Errorf("local IP not provided") + } + + addr, err := netip.ParseAddr(pm.localIP) + if err != nil { + return fmt.Errorf("invalid local IP: %v", err) + } + + // Create gvisor netstack + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + 
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + HandleLocal: true, + } + + pm.ep = channel.New(256, 1420, "") // MTU 1420 (standard WG) + pm.stack = stack.New(stackOpts) + + // Create NIC + if err := pm.stack.CreateNIC(1, pm.ep); err != nil { + return fmt.Errorf("failed to create NIC: %v", err) + } + + // Add IP address + ipBytes := addr.As4() + protoAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(), + } + + if err := pm.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil { + return fmt.Errorf("failed to add protocol address: %v", err) + } + + // Add default route + pm.stack.AddRoute(tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: 1, + }) + + // Register filter rule on MiddleDevice + // We want to intercept packets destined to our local IP + // But ONLY if they are for ports we are listening on + pm.middleDev.AddRule(addr, pm.handlePacket) + + // Start packet sender (Stack -> WG) + pm.nsWg.Add(1) + go pm.runPacketSender() + + return nil +} + +// handlePacket is called by MiddleDevice when a packet arrives for our IP +func (pm *PeerMonitor) handlePacket(packet []byte) bool { + // Check if it's UDP + proto, ok := util.GetProtocol(packet) + if !ok || proto != 17 { // UDP + return false + } + + // Check destination port + port, ok := util.GetDestPort(packet) + if !ok { + return false + } + + // Check if we are listening on this port + pm.portsLock.Lock() + active := pm.activePorts[uint16(port)] + pm.portsLock.Unlock() + + if !active { + return false + } + + // Inject into netstack + version := packet[0] >> 4 + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(packet), + }) + + switch version { + case 4: + pm.ep.InjectInbound(ipv4.ProtocolNumber, pkb) + case 6: + pm.ep.InjectInbound(ipv6.ProtocolNumber, pkb) + default: + pkb.DecRef() + return false + } + + pkb.DecRef() + return true // Handled +} + 
+// runPacketSender reads packets from netstack and injects them into WireGuard +func (pm *PeerMonitor) runPacketSender() { + defer pm.nsWg.Done() + logger.Debug("PeerMonitor: Packet sender goroutine started") + + // Use a ticker to periodically check for packets without blocking indefinitely + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-pm.nsCtx.Done(): + logger.Debug("PeerMonitor: Packet sender context cancelled, draining packets") + // Drain any remaining packets before exiting + for { + pkt := pm.ep.Read() + if pkt == nil { + break + } + pkt.DecRef() + } + logger.Debug("PeerMonitor: Packet sender goroutine exiting") + return + case <-ticker.C: + // Try to read packets in batches + for i := 0; i < 10; i++ { + pkt := pm.ep.Read() + if pkt == nil { + break + } + + // Extract packet data + slices := pkt.AsSlices() + if len(slices) > 0 { + var totalSize int + for _, slice := range slices { + totalSize += len(slice) + } + + buf := make([]byte, totalSize) + pos := 0 + for _, slice := range slices { + copy(buf[pos:], slice) + pos += len(slice) + } + + // Inject into MiddleDevice (outbound to WG) + pm.middleDev.InjectOutbound(buf) + } + + pkt.DecRef() + } + } + } +} + +// dial creates a UDP connection using the netstack +func (pm *PeerMonitor) dial(network, addr string) (net.Conn, error) { + if pm.stack == nil { + return nil, fmt.Errorf("netstack not initialized") + } + + // Parse remote address + raddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + + // Parse local IP + localIP, err := netip.ParseAddr(pm.localIP) + if err != nil { + return nil, err + } + ipBytes := localIP.As4() + + // Create UDP connection + // We bind to port 0 (ephemeral) + laddr := &tcpip.FullAddress{ + NIC: 1, + Addr: tcpip.AddrFrom4(ipBytes), + Port: 0, + } + + raddrTcpip := &tcpip.FullAddress{ + NIC: 1, + Addr: tcpip.AddrFrom4([4]byte(raddr.IP.To4())), + Port: uint16(raddr.Port), + } + + conn, err := 
gonet.DialUDP(pm.stack, laddr, raddrTcpip, ipv4.ProtocolNumber) + if err != nil { + return nil, err + } + + // Get local port + localAddr := conn.LocalAddr().(*net.UDPAddr) + port := uint16(localAddr.Port) + + // Register port + pm.portsLock.Lock() + pm.activePorts[port] = true + pm.portsLock.Unlock() + + // Wrap connection to cleanup port on close + return &trackedConn{ + Conn: conn, + pm: pm, + port: port, + }, nil +} + +func (pm *PeerMonitor) removePort(port uint16) { + pm.portsLock.Lock() + delete(pm.activePorts, port) + pm.portsLock.Unlock() +} + +type trackedConn struct { + net.Conn + pm *PeerMonitor + port uint16 +} + +func (c *trackedConn) Close() error { + c.pm.removePort(c.port) + if c.Conn != nil { + return c.Conn.Close() + } + return nil +} diff --git a/wgtester/wgtester.go b/peers/monitor/wgtester.go similarity index 87% rename from wgtester/wgtester.go rename to peers/monitor/wgtester.go index 28ffdba..dac2008 100644 --- a/wgtester/wgtester.go +++ b/peers/monitor/wgtester.go @@ -1,4 +1,4 @@ -package wgtester +package monitor import ( "context" @@ -26,7 +26,7 @@ const ( // Client handles checking connectivity to a server type Client struct { - conn *net.UDPConn + conn net.Conn serverAddr string monitorRunning bool monitorLock sync.Mutex @@ -35,8 +35,12 @@ type Client struct { packetInterval time.Duration timeout time.Duration maxAttempts int + dialer Dialer } +// Dialer is a function that creates a connection +type Dialer func(network, addr string) (net.Conn, error) + // ConnectionStatus represents the current connection state type ConnectionStatus struct { Connected bool @@ -44,13 +48,14 @@ type ConnectionStatus struct { } // NewClient creates a new connection test client -func NewClient(serverAddr string) (*Client, error) { +func NewClient(serverAddr string, dialer Dialer) (*Client, error) { return &Client{ serverAddr: serverAddr, shutdownCh: make(chan struct{}), packetInterval: 2 * time.Second, timeout: 500 * time.Millisecond, // Timeout for 
individual packets maxAttempts: 3, // Default max attempts + dialer: dialer, }, nil } @@ -69,6 +74,20 @@ func (c *Client) SetMaxAttempts(attempts int) { c.maxAttempts = attempts } +// UpdateServerAddr updates the server address and resets the connection +func (c *Client) UpdateServerAddr(serverAddr string) { + c.connLock.Lock() + defer c.connLock.Unlock() + + // Close existing connection if any + if c.conn != nil { + c.conn.Close() + c.conn = nil + } + + c.serverAddr = serverAddr +} + // Close cleans up client resources func (c *Client) Close() { c.StopMonitor() @@ -91,12 +110,14 @@ func (c *Client) ensureConnection() error { return nil } - serverAddr, err := net.ResolveUDPAddr("udp", c.serverAddr) - if err != nil { - return err + var err error + if c.dialer != nil { + c.conn, err = c.dialer("udp", c.serverAddr) + } else { + // Fallback to standard net.Dial + c.conn, err = net.Dial("udp", c.serverAddr) } - c.conn, err = net.DialUDP("udp", nil, serverAddr) if err != nil { return err } @@ -136,14 +157,14 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { return false, 0 } - logger.Debug("Attempting to send monitor packet to %s", c.serverAddr) + // logger.Debug("Attempting to send monitor packet to %s", c.serverAddr) _, err := c.conn.Write(packet) if err != nil { c.connLock.Unlock() logger.Info("Error sending packet: %v", err) continue } - logger.Debug("Successfully sent monitor packet") + // logger.Debug("Successfully sent monitor packet") // Set read deadline c.conn.SetReadDeadline(time.Now().Add(c.timeout)) diff --git a/peers/peer.go b/peers/peer.go new file mode 100644 index 0000000..9370b9d --- /dev/null +++ b/peers/peer.go @@ -0,0 +1,142 @@ +package peers + +import ( + "fmt" + "strings" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/util" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +// ConfigurePeer sets up or updates a peer within the WireGuard device +func ConfigurePeer(dev 
*device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool) error { + var endpoint string + if relay && siteConfig.RelayEndpoint != "" { + endpoint = formatEndpoint(siteConfig.RelayEndpoint) + } else { + endpoint = formatEndpoint(siteConfig.Endpoint) + } + siteHost, err := util.ResolveDomain(endpoint) + if err != nil { + return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err) + } + + // Split off the CIDR of the server IP which is just a string and add /32 for the allowed IP + allowedIp := strings.Split(siteConfig.ServerIP, "/") + if len(allowedIp) > 1 { + allowedIp[1] = "32" + } else { + allowedIp = append(allowedIp, "32") + } + allowedIpStr := strings.Join(allowedIp, "/") + + // Collect all allowed IPs in a slice + var allowedIPs []string + allowedIPs = append(allowedIPs, allowedIpStr) + + // Use AllowedIps if available, otherwise fall back to RemoteSubnets for backwards compatibility + subnetsToAdd := siteConfig.AllowedIps + + // If we have anything to add, process them + if len(subnetsToAdd) > 0 { + // Add each subnet + for _, subnet := range subnetsToAdd { + subnet = strings.TrimSpace(subnet) + if subnet != "" { + allowedIPs = append(allowedIPs, subnet) + } + } + } + + // Construct WireGuard config for this peer + var configBuilder strings.Builder + configBuilder.WriteString(fmt.Sprintf("private_key=%s\n", util.FixKey(privateKey.String()))) + configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(siteConfig.PublicKey))) + + // Add each allowed IP separately + for _, allowedIP := range allowedIPs { + configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP)) + } + + configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost)) + configBuilder.WriteString("persistent_keepalive_interval=5\n") + + config := configBuilder.String() + logger.Debug("Configuring peer with config: %s", config) + + err = dev.IpcSet(config) + if err != nil { + return fmt.Errorf("failed to configure WireGuard peer: 
%v", err) + } + + return nil +} + +// RemovePeer removes a peer from the WireGuard device +func RemovePeer(dev *device.Device, siteId int, publicKey string) error { + // Construct WireGuard config to remove the peer + var configBuilder strings.Builder + configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey))) + configBuilder.WriteString("remove=true\n") + + config := configBuilder.String() + logger.Debug("Removing peer with config: %s", config) + + err := dev.IpcSet(config) + if err != nil { + return fmt.Errorf("failed to remove WireGuard peer: %v", err) + } + + return nil +} + +// AddAllowedIP adds a single allowed IP to an existing peer without reconfiguring the entire peer +func AddAllowedIP(dev *device.Device, publicKey string, allowedIP string) error { + var configBuilder strings.Builder + configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey))) + configBuilder.WriteString("update_only=true\n") + configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP)) + + config := configBuilder.String() + logger.Debug("Adding allowed IP to peer with config: %s", config) + + err := dev.IpcSet(config) + if err != nil { + return fmt.Errorf("failed to add allowed IP to WireGuard peer: %v", err) + } + + return nil +} + +// RemoveAllowedIP removes a single allowed IP from an existing peer by replacing the allowed IPs list +// This requires providing all the allowed IPs that should remain after removal +func RemoveAllowedIP(dev *device.Device, publicKey string, remainingAllowedIPs []string) error { + var configBuilder strings.Builder + configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey))) + configBuilder.WriteString("update_only=true\n") + configBuilder.WriteString("replace_allowed_ips=true\n") + + // Add each remaining allowed IP + for _, allowedIP := range remainingAllowedIPs { + configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP)) + } + + config := configBuilder.String() + 
logger.Debug("Removing allowed IP from peer with config: %s", config) + + err := dev.IpcSet(config) + if err != nil { + return fmt.Errorf("failed to remove allowed IP from WireGuard peer: %v", err) + } + + return nil +} + +func formatEndpoint(endpoint string) string { + if strings.Contains(endpoint, ":") { + return endpoint + } + return endpoint + ":51820" +} diff --git a/peers/types.go b/peers/types.go new file mode 100644 index 0000000..dab49e1 --- /dev/null +++ b/peers/types.go @@ -0,0 +1,63 @@ +package peers + +// PeerAction represents a request to add, update, or remove a peer +type PeerAction struct { + Action string `json:"action"` // "add", "update", or "remove" + SiteInfo SiteConfig `json:"siteInfo"` // Site configuration information +} + +// UpdatePeerData represents the data needed to update a peer +type SiteConfig struct { + SiteId int `json:"siteId"` + Name string `json:"name,omitempty"` + Endpoint string `json:"endpoint,omitempty"` + RelayEndpoint string `json:"relayEndpoint,omitempty"` + PublicKey string `json:"publicKey,omitempty"` + ServerIP string `json:"serverIP,omitempty"` + ServerPort uint16 `json:"serverPort,omitempty"` + RemoteSubnets []string `json:"remoteSubnets,omitempty"` // optional, array of subnets that this site can access + AllowedIps []string `json:"allowedIps,omitempty"` // optional, array of allowed IPs for the peer + Aliases []Alias `json:"aliases,omitempty"` // optional, array of alias configurations +} + +type Alias struct { + Alias string `json:"alias"` // the alias name + AliasAddress string `json:"aliasAddress"` // the alias IP address +} + +// RemovePeer represents the data needed to remove a peer +type PeerRemove struct { + SiteId int `json:"siteId"` +} + +type RelayPeerData struct { + SiteId int `json:"siteId"` + RelayEndpoint string `json:"relayEndpoint"` +} + +type UnRelayPeerData struct { + SiteId int `json:"siteId"` + Endpoint string `json:"endpoint"` +} + +// PeerAdd represents the data needed to add remote subnets 
to a peer +type PeerAdd struct { + SiteId int `json:"siteId"` + RemoteSubnets []string `json:"remoteSubnets"` // subnets to add + Aliases []Alias `json:"aliases,omitempty"` // aliases to add +} + +// RemovePeerData represents the data needed to remove remote subnets from a peer +type RemovePeerData struct { + SiteId int `json:"siteId"` + RemoteSubnets []string `json:"remoteSubnets"` // subnets to remove + Aliases []Alias `json:"aliases,omitempty"` // aliases to remove +} + +type UpdatePeerData struct { + SiteId int `json:"siteId"` + OldRemoteSubnets []string `json:"oldRemoteSubnets"` // old list of remote subnets + NewRemoteSubnets []string `json:"newRemoteSubnets"` // new list of remote subnets + OldAliases []Alias `json:"oldAliases,omitempty"` // old list of aliases + NewAliases []Alias `json:"newAliases,omitempty"` // new list of aliases +} diff --git a/service_windows.go b/service_windows.go index dc941f3..48e79ce 100644 --- a/service_windows.go +++ b/service_windows.go @@ -99,15 +99,32 @@ func (s *olmService) Execute(args []string, r <-chan svc.ChangeRequest, changes // Continue with empty args if loading fails savedArgs = []string{} } + s.elog.Info(1, fmt.Sprintf("Loaded saved service args: %v", savedArgs)) // Combine service start args with saved args, giving priority to service start args + // Note: When the service is started via SCM, args[0] is the service name + // When started via s.Start(args...), the args passed are exactly what we provide finalArgs := []string{} + + // Check if we have args passed directly to Execute (from s.Start()) if len(args) > 0 { - // Skip the first arg which is typically the service name - if len(args) > 1 { + // The first arg from SCM is the service name, but when we call s.Start(args...), + // the args we pass become args[1:] in Execute. However, if started by SCM without + // args, args[0] will be the service name. 
+ // We need to check if args[0] looks like the service name or a flag + if len(args) == 1 && args[0] == serviceName { + // Only service name, no actual args + s.elog.Info(1, "Only service name in args, checking saved args") + } else if len(args) > 1 && args[0] == serviceName { + // Service name followed by actual args finalArgs = append(finalArgs, args[1:]...) + s.elog.Info(1, fmt.Sprintf("Using service start parameters (after service name): %v", finalArgs)) + } else { + // Args don't start with service name, use them all + // This happens when args are passed via s.Start(args...) + finalArgs = append(finalArgs, args...) + s.elog.Info(1, fmt.Sprintf("Using service start parameters (direct): %v", finalArgs)) } - s.elog.Info(1, fmt.Sprintf("Using service start parameters: %v", finalArgs)) } // If no service start parameters, use saved args @@ -116,6 +133,7 @@ func (s *olmService) Execute(args []string, r <-chan svc.ChangeRequest, changes s.elog.Info(1, fmt.Sprintf("Using saved service args: %v", finalArgs)) } + s.elog.Info(1, fmt.Sprintf("Final args to use: %v", finalArgs)) s.args = finalArgs // Start the main olm functionality @@ -163,6 +181,9 @@ func (s *olmService) runOlm() { // Create a context that can be cancelled when the service stops s.ctx, s.stop = context.WithCancel(context.Background()) + // Create a separate context for programmatic shutdown (e.g., via API exit) + ctx, cancel := context.WithCancel(context.Background()) + // Setup logging for service mode s.elog.Info(1, "Starting Olm main logic") @@ -177,7 +198,8 @@ func (s *olmService) runOlm() { }() // Call the main olm function with stored arguments - runOlmMainWithArgs(s.ctx, s.args) + // Use s.ctx as the signal context since the service manages shutdown + runOlmMainWithArgs(ctx, cancel, s.ctx, s.args) }() // Wait for either context cancellation or main logic completion @@ -321,12 +343,15 @@ func removeService() error { } func startService(args []string) error { - // Save the service arguments as 
backup - if len(args) > 0 { - err := saveServiceArgs(args) - if err != nil { - return fmt.Errorf("failed to save service args: %v", err) - } + fmt.Printf("Starting service with args: %v\n", args) + + // Always save the service arguments so they can be loaded on service restart + err := saveServiceArgs(args) + if err != nil { + fmt.Printf("Warning: failed to save service args: %v\n", err) + // Continue anyway, args will still be passed directly + } else { + fmt.Printf("Saved service args to: %s\n", getServiceArgsPath()) } m, err := mgr.Connect() @@ -342,6 +367,7 @@ func startService(args []string) error { defer s.Close() // Pass arguments directly to the service start call + // Note: These args will appear in Execute() after the service name err = s.Start(args...) if err != nil { return fmt.Errorf("failed to start service: %v", err) diff --git a/unix.go b/unix.go deleted file mode 100644 index 3a9c09e..0000000 --- a/unix.go +++ /dev/null @@ -1,35 +0,0 @@ -//go:build !windows - -package main - -import ( - "net" - "os" - "strconv" - - "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/ipc" - "golang.zx2c4.com/wireguard/tun" -) - -func createTUNFromFD(tunFdStr string, mtuInt int) (tun.Device, error) { - fd, err := strconv.ParseUint(tunFdStr, 10, 32) - if err != nil { - return nil, err - } - - err = unix.SetNonblock(int(fd), true) - if err != nil { - return nil, err - } - - file := os.NewFile(uintptr(fd), "") - return tun.CreateTUNFromFile(file, mtuInt) -} -func uapiOpen(interfaceName string) (*os.File, error) { - return ipc.UAPIOpen(interfaceName) -} - -func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) { - return ipc.UAPIListen(interfaceName, fileUAPI) -} diff --git a/websocket/client.go b/websocket/client.go index d1ab3da..b9f5a63 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -20,14 +20,36 @@ import ( "github.com/gorilla/websocket" ) +// AuthError represents an authentication/authorization error (401/403) +type AuthError 
struct { + StatusCode int + Message string +} + +func (e *AuthError) Error() string { + return fmt.Sprintf("authentication error (status %d): %s", e.StatusCode, e.Message) +} + +// IsAuthError checks if an error is an authentication error +func IsAuthError(err error) bool { + _, ok := err.(*AuthError) + return ok +} + type TokenResponse struct { Data struct { - Token string `json:"token"` + Token string `json:"token"` + ExitNodes []ExitNode `json:"exitNodes"` } `json:"data"` Success bool `json:"success"` Message string `json:"message"` } +type ExitNode struct { + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` +} + type WSMessage struct { Type string `json:"type"` Data interface{} `json:"data"` @@ -39,6 +61,8 @@ type Config struct { Secret string Endpoint string TlsClientCert string // legacy PKCS12 file path + UserToken string // optional user token for websocket authentication + OrgID string // optional organization ID for websocket authentication } type Client struct { @@ -54,7 +78,8 @@ type Client struct { pingInterval time.Duration pingTimeout time.Duration onConnect func() error - onTokenUpdate func(token string) + onTokenUpdate func(token string, exitNodes []ExitNode) + onAuthError func(statusCode int, message string) // Callback for auth errors writeMux sync.Mutex clientType string // Type of client (e.g., "newt", "olm") tlsConfig TLSConfig @@ -98,16 +123,22 @@ func (c *Client) OnConnect(callback func() error) { c.onConnect = callback } -func (c *Client) OnTokenUpdate(callback func(token string)) { +func (c *Client) OnTokenUpdate(callback func(token string, exitNodes []ExitNode)) { c.onTokenUpdate = callback } +func (c *Client) OnAuthError(callback func(statusCode int, message string)) { + c.onAuthError = callback +} + // NewClient creates a new websocket client -func NewClient(clientType string, ID, secret string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) { +func 
NewClient(ID, secret, userToken, orgId, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) { config := &Config{ - ID: ID, - Secret: secret, - Endpoint: endpoint, + ID: ID, + Secret: secret, + Endpoint: endpoint, + UserToken: userToken, + OrgID: orgId, } client := &Client{ @@ -119,7 +150,7 @@ func NewClient(clientType string, ID, secret string, endpoint string, pingInterv isConnected: false, pingInterval: pingInterval, pingTimeout: pingTimeout, - clientType: clientType, + clientType: "olm", } // Apply options before loading config @@ -189,13 +220,17 @@ func (c *Client) SendMessage(messageType string, data interface{}) error { return c.conn.WriteJSON(msg) } -func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) { +func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) { stopChan := make(chan struct{}) + updateChan := make(chan interface{}) + var dataMux sync.Mutex + currentData := data + go func() { count := 0 maxAttempts := 10 - err := c.SendMessage(messageType, data) // Send immediately + err := c.SendMessage(messageType, currentData) // Send immediately if err != nil { logger.Error("Failed to send initial message: %v", err) } @@ -210,19 +245,46 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType) return } - err = c.SendMessage(messageType, data) + dataMux.Lock() + err = c.SendMessage(messageType, currentData) + dataMux.Unlock() if err != nil { logger.Error("Failed to send message: %v", err) } count++ + case newData := <-updateChan: + dataMux.Lock() + // Merge newData into currentData if both are maps + if currentMap, ok := currentData.(map[string]interface{}); ok { + if newMap, ok := newData.(map[string]interface{}); 
ok { + // Update or add keys from newData + for key, value := range newMap { + currentMap[key] = value + } + currentData = currentMap + } else { + // If newData is not a map, replace entirely + currentData = newData + } + } else { + // If currentData is not a map, replace entirely + currentData = newData + } + dataMux.Unlock() case <-stopChan: return } } }() return func() { - close(stopChan) - } + close(stopChan) + }, func(newData interface{}) { + select { + case updateChan <- newData: + case <-stopChan: + // Channel is closed, ignore update + } + } } // RegisterHandler registers a handler for a specific message type @@ -232,11 +294,11 @@ func (c *Client) RegisterHandler(messageType string, handler MessageHandler) { c.handlers[messageType] = handler } -func (c *Client) getToken() (string, error) { +func (c *Client) getToken() (string, []ExitNode, error) { // Parse the base URL to ensure we have the correct hostname baseURL, err := url.Parse(c.baseURL) if err != nil { - return "", fmt.Errorf("failed to parse base URL: %w", err) + return "", nil, fmt.Errorf("failed to parse base URL: %w", err) } // Ensure we have the base URL without trailing slashes @@ -248,7 +310,7 @@ func (c *Client) getToken() (string, error) { if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" { tlsConfig, err = c.setupTLS() if err != nil { - return "", fmt.Errorf("failed to setup TLS configuration: %w", err) + return "", nil, fmt.Errorf("failed to setup TLS configuration: %w", err) } } @@ -261,24 +323,15 @@ func (c *Client) getToken() (string, error) { logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") } - var tokenData map[string]interface{} - - // Get a new token - if c.clientType == "newt" { - tokenData = map[string]interface{}{ - "newtId": c.config.ID, - "secret": c.config.Secret, - } - } else if c.clientType == "olm" { - tokenData = map[string]interface{}{ - 
"olmId": c.config.ID, - "secret": c.config.Secret, - } + tokenData := map[string]interface{}{ + "olmId": c.config.ID, + "secret": c.config.Secret, + "orgId": c.config.OrgID, } jsonData, err := json.Marshal(tokenData) if err != nil { - return "", fmt.Errorf("failed to marshal token request data: %w", err) + return "", nil, fmt.Errorf("failed to marshal token request data: %w", err) } // Create a new request @@ -288,13 +341,16 @@ func (c *Client) getToken() (string, error) { bytes.NewBuffer(jsonData), ) if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) + return "", nil, fmt.Errorf("failed to create request: %w", err) } // Set headers req.Header.Set("Content-Type", "application/json") req.Header.Set("X-CSRF-Token", "x-csrf-protection") + // print out the request for debugging + logger.Debug("Requesting token from %s with body: %s", req.URL.String(), string(jsonData)) + // Make the request client := &http.Client{} if tlsConfig != nil { @@ -304,33 +360,43 @@ func (c *Client) getToken() (string, error) { } resp, err := client.Do(req) if err != nil { - return "", fmt.Errorf("failed to request new token: %w", err) + return "", nil, fmt.Errorf("failed to request new token: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) - return "", fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) + + // Return AuthError for 401/403 status codes + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { + return "", nil, &AuthError{ + StatusCode: resp.StatusCode, + Message: string(body), + } + } + + // For other errors (5xx, network issues, etc.), return regular error + return "", nil, fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) } var tokenResp TokenResponse if err := 
json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { logger.Error("Failed to decode token response.") - return "", fmt.Errorf("failed to decode token response: %w", err) + return "", nil, fmt.Errorf("failed to decode token response: %w", err) } if !tokenResp.Success { - return "", fmt.Errorf("failed to get token: %s", tokenResp.Message) + return "", nil, fmt.Errorf("failed to get token: %s", tokenResp.Message) } if tokenResp.Data.Token == "" { - return "", fmt.Errorf("received empty token from server") + return "", nil, fmt.Errorf("received empty token from server") } logger.Debug("Received token: %s", tokenResp.Data.Token) - return tokenResp.Data.Token, nil + return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil } func (c *Client) connectWithRetry() { @@ -341,6 +407,18 @@ func (c *Client) connectWithRetry() { default: err := c.establishConnection() if err != nil { + // Check if this is an auth error (401/403) + if authErr, ok := err.(*AuthError); ok { + logger.Error("Authentication failed: %v. Terminating tunnel and retrying...", authErr) + // Trigger auth error callback if set (this should terminate the tunnel) + if c.onAuthError != nil { + c.onAuthError(authErr.StatusCode, authErr.Message) + } + // Continue retrying after auth error + time.Sleep(c.reconnectInterval) + continue + } + // For other errors (5xx, network issues), continue retrying logger.Error("Failed to connect: %v. 
Retrying in %v...", err, c.reconnectInterval) time.Sleep(c.reconnectInterval) continue @@ -352,13 +430,13 @@ func (c *Client) connectWithRetry() { func (c *Client) establishConnection() error { // Get token for authentication - token, err := c.getToken() + token, exitNodes, err := c.getToken() if err != nil { return fmt.Errorf("failed to get token: %w", err) } if c.onTokenUpdate != nil { - c.onTokenUpdate(token) + c.onTokenUpdate(token, exitNodes) } // Parse the base URL to determine protocol and hostname @@ -384,6 +462,9 @@ func (c *Client) establishConnection() error { q := u.Query() q.Set("token", token) q.Set("clientType", c.clientType) + if c.config.UserToken != "" { + q.Set("userToken", c.config.UserToken) + } u.RawQuery = q.Encode() // Connect to WebSocket