From e464af5302558131ed208b32cbb6b4e437de713c Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 2 Nov 2025 18:56:09 -0800 Subject: [PATCH] Make api availble over socket --- httpserver/httpserver.go => api/api.go | 82 +++++++++++++++++++------- api/api_unix.go | 50 ++++++++++++++++ api/api_windows.go | 41 +++++++++++++ config.go | 63 ++++++++++++++------ main.go | 11 +++- olm/olm.go | 78 ++++++++++++++---------- 6 files changed, 253 insertions(+), 72 deletions(-) rename httpserver/httpserver.go => api/api.go (68%) create mode 100644 api/api_unix.go create mode 100644 api/api_windows.go diff --git a/httpserver/httpserver.go b/api/api.go similarity index 68% rename from httpserver/httpserver.go rename to api/api.go index 4f57cca..c7dfcf3 100644 --- a/httpserver/httpserver.go +++ b/api/api.go @@ -1,8 +1,9 @@ -package httpserver +package api import ( "encoding/json" "fmt" + "net" "net/http" "sync" "time" @@ -36,9 +37,11 @@ type StatusResponse struct { PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"` } -// HTTPServer represents the HTTP server and its state -type HTTPServer struct { +// API represents the HTTP server and its state +type API struct { addr string + socketPath string + listener net.Listener server *http.Server connectionChan chan ConnectionRequest statusMu sync.RWMutex @@ -49,9 +52,9 @@ type HTTPServer struct { version string } -// NewHTTPServer creates a new HTTP server -func NewHTTPServer(addr string) *HTTPServer { - s := &HTTPServer{ +// NewAPI creates a new HTTP server that listens on a TCP address +func NewAPI(addr string) *API { + s := &API{ addr: addr, connectionChan: make(chan ConnectionRequest, 1), peerStatuses: make(map[int]*PeerStatus), @@ -60,20 +63,46 @@ func NewHTTPServer(addr string) *HTTPServer { return s } +// NewAPISocket creates a new HTTP server that listens on a Unix socket or Windows named pipe +func NewAPISocket(socketPath string) *API { + s := &API{ + socketPath: socketPath, + connectionChan: make(chan ConnectionRequest, 1), + peerStatuses: make(map[int]*PeerStatus), + } + + return s +} + // Start starts the HTTP server -func (s *HTTPServer) Start() error { +func (s *API) Start() error { mux := http.NewServeMux() mux.HandleFunc("/connect", s.handleConnect) mux.HandleFunc("/status", s.handleStatus) s.server = &http.Server{ - Addr: s.addr, Handler: mux, } - logger.Info("Starting HTTP server on %s", s.addr) + var err error + if s.socketPath != "" { + // Use platform-specific socket listener + s.listener, err = createSocketListener(s.socketPath) + if err != nil { + return fmt.Errorf("failed to create socket listener: %w", err) + } + logger.Info("Starting HTTP server on socket %s", s.socketPath) + } else { + // Use TCP listener + s.listener, err = net.Listen("tcp", s.addr) + if err != nil { + return fmt.Errorf("failed to create TCP listener: %w", err) + } + logger.Info("Starting HTTP server on %s", s.addr) + } + go func() { - if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + if err := s.server.Serve(s.listener); err != nil && err != http.ErrServerClosed { logger.Error("HTTP server error: %v", err) } }() @@ -82,18 +111,29 @@ func (s *HTTPServer) Start() error { } // Stop stops the HTTP server -func (s *HTTPServer) Stop() error { - logger.Info("Stopping HTTP server") - return s.server.Close() +func (s *API) Stop() error { + logger.Info("Stopping api server") + + // Close the server first, which will also close the listener gracefully + if s.server != nil { + s.server.Close() + } + + // Clean up socket file if using Unix socket + if s.socketPath != "" { + cleanupSocket(s.socketPath) + } + + return nil } // GetConnectionChannel returns the channel for receiving connection requests -func (s *HTTPServer) GetConnectionChannel() <-chan ConnectionRequest { +func (s *API) GetConnectionChannel() <-chan ConnectionRequest { return s.connectionChan } // UpdatePeerStatus updates the status of a peer including endpoint and relay info -func (s *HTTPServer) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) { +func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) { s.statusMu.Lock() defer s.statusMu.Unlock() @@ -113,7 +153,7 @@ func (s *HTTPServer) UpdatePeerStatus(siteID int, connected bool, rtt time.Durat } // SetConnectionStatus sets the overall connection status -func (s *HTTPServer) SetConnectionStatus(isConnected bool) { +func (s *API) SetConnectionStatus(isConnected bool) { s.statusMu.Lock() defer s.statusMu.Unlock() @@ -128,21 +168,21 @@ func (s *HTTPServer) SetConnectionStatus(isConnected bool) { } // SetTunnelIP sets the tunnel IP address -func (s *HTTPServer) SetTunnelIP(tunnelIP string) { +func (s *API) SetTunnelIP(tunnelIP string) { s.statusMu.Lock() defer s.statusMu.Unlock() s.tunnelIP = tunnelIP } // SetVersion sets the olm version -func (s *HTTPServer) SetVersion(version string) { +func (s *API) SetVersion(version string) { s.statusMu.Lock() defer s.statusMu.Unlock() s.version = version } // UpdatePeerRelayStatus updates only the relay status of a peer -func (s *HTTPServer) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) { +func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) { s.statusMu.Lock() defer s.statusMu.Unlock() @@ -159,7 +199,7 @@ func (s *HTTPServer) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay } // handleConnect handles the /connect endpoint -func (s *HTTPServer) handleConnect(w http.ResponseWriter, r *http.Request) { +func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return @@ -190,7 +230,7 @@ func (s *HTTPServer) handleConnect(w http.ResponseWriter, r *http.Request) { } // handleStatus handles the /status endpoint -func (s *HTTPServer) handleStatus(w http.ResponseWriter, r *http.Request) { +func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return diff --git a/api/api_unix.go b/api/api_unix.go new file mode 100644 index 0000000..2dab602 --- /dev/null +++ b/api/api_unix.go @@ -0,0 +1,50 @@ +//go:build !windows +// +build !windows + +package api + +import ( + "fmt" + "net" + "os" + "path/filepath" + + "github.com/fosrl/newt/logger" +) + +// createSocketListener creates a Unix domain socket listener +func createSocketListener(socketPath string) (net.Listener, error) { + // Ensure the directory exists + dir := filepath.Dir(socketPath) + if err := os.MkdirAll(dir, 0755); err != nil { + return nil, fmt.Errorf("failed to create socket directory: %w", err) + } + + // Remove existing socket file if it exists + if err := os.RemoveAll(socketPath); err != nil { + return nil, fmt.Errorf("failed to remove existing socket: %w", err) + } + + listener, err := net.Listen("unix", socketPath) + if err != nil { + return nil, fmt.Errorf("failed to listen on Unix socket: %w", err) + } + + // Set socket permissions to allow access + if err := os.Chmod(socketPath, 0666); err != nil { + listener.Close() + return nil, fmt.Errorf("failed to set socket permissions: %w", err) + } + + logger.Debug("Created Unix socket at %s", socketPath) + return listener, nil +} + +// cleanupSocket removes the Unix socket file +func cleanupSocket(socketPath string) { + if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) { + logger.Error("Failed to remove socket file %s: %v", socketPath, err) + } else { + logger.Debug("Removed Unix socket at %s", socketPath) + } +} diff --git a/api/api_windows.go b/api/api_windows.go new file mode 100644 index 0000000..d9ef373 --- /dev/null +++ b/api/api_windows.go @@ -0,0 +1,41 @@ +//go:build windows +// +build windows + +package api + +import ( + "fmt" + "net" + + "github.com/Microsoft/go-winio" + "github.com/fosrl/newt/logger" +) + +// createSocketListener creates a Windows named pipe listener +func createSocketListener(pipePath string) (net.Listener, error) { + // Ensure the pipe path has the correct format + if pipePath[0] != '\\' { + pipePath = `\\.\pipe\` + pipePath + } + + // Create a pipe configuration that allows everyone to write + config := &winio.PipeConfig{ + // Set security descriptor to allow everyone full access + // This SDDL string grants full access to Everyone (WD) and to the current owner (OW) + SecurityDescriptor: "D:(A;;GA;;;WD)(A;;GA;;;OW)", + } + + // Create a named pipe listener using go-winio with the configuration + listener, err := winio.ListenPipe(pipePath, config) + if err != nil { + return nil, fmt.Errorf("failed to listen on named pipe: %w", err) + } + + logger.Debug("Created named pipe at %s with write access for everyone", pipePath) + return listener, nil +} + +// cleanupSocket is a no-op on Windows as named pipes are automatically cleaned up +func cleanupSocket(pipePath string) { + logger.Debug("Named pipe %s will be automatically cleaned up", pipePath) +} diff --git a/config.go b/config.go index 0aaa9c8..191e517 100644 --- a/config.go +++ b/config.go @@ -27,8 +27,9 @@ type OlmConfig struct { LogLevel string `json:"logLevel"` // HTTP server - EnableHTTP bool `json:"enableHttp"` + EnableAPI bool `json:"enableApi"` HTTPAddr string `json:"httpAddr"` + SocketPath string `json:"socketPath"` // Ping settings PingInterval string `json:"pingInterval"` @@ -60,13 +61,22 @@ const ( // DefaultConfig returns a config with default values func DefaultConfig() *OlmConfig { + // Set OS-specific socket path + var socketPath string + switch runtime.GOOS { + case "windows": + socketPath = "olm" + default: // darwin, linux, and others + socketPath = "/var/run/olm.sock" + } + config := &OlmConfig{ MTU: 1280, DNS: "8.8.8.8", LogLevel: "INFO", InterfaceName: "olm", - EnableHTTP: false, - HTTPAddr: ":9452", + EnableAPI: false, + SocketPath: socketPath, PingInterval: "3s", PingTimeout: "5s", Holepunch: false, @@ -78,8 +88,9 @@ func DefaultConfig() *OlmConfig { config.sources["dns"] = string(SourceDefault) config.sources["logLevel"] = string(SourceDefault) config.sources["interface"] = string(SourceDefault) - config.sources["enableHttp"] = string(SourceDefault) + config.sources["enableApi"] = string(SourceDefault) config.sources["httpAddr"] = string(SourceDefault) + config.sources["socketPath"] = string(SourceDefault) config.sources["pingInterval"] = string(SourceDefault) config.sources["pingTimeout"] = string(SourceDefault) config.sources["holepunch"] = string(SourceDefault) @@ -209,9 +220,13 @@ func loadConfigFromEnv(config *OlmConfig) { config.PingTimeout = val config.sources["pingTimeout"] = string(SourceEnv) } - if val := os.Getenv("ENABLE_HTTP"); val == "true" { - config.EnableHTTP = true - config.sources["enableHttp"] = string(SourceEnv) + if val := os.Getenv("ENABLE_API"); val == "true" { + config.EnableAPI = true + config.sources["enableApi"] = string(SourceEnv) + } + if val := os.Getenv("SOCKET_PATH"); val != "" { + config.SocketPath = val + config.sources["socketPath"] = string(SourceEnv) } if val := os.Getenv("HOLEPUNCH"); val == "true" { config.Holepunch = true @@ -233,9 +248,10 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { "logLevel": config.LogLevel, "interface": config.InterfaceName, "httpAddr": config.HTTPAddr, + "socketPath": config.SocketPath, "pingInterval": config.PingInterval, "pingTimeout": config.PingTimeout, - "enableHttp": config.EnableHTTP, + "enableApi": config.EnableAPI, "holepunch": config.Holepunch, } @@ -248,9 +264,10 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") serviceFlags.StringVar(&config.InterfaceName, "interface", config.InterfaceName, "Name of the WireGuard interface") serviceFlags.StringVar(&config.HTTPAddr, "http-addr", config.HTTPAddr, "HTTP server address (e.g., ':9452')") + serviceFlags.StringVar(&config.SocketPath, "socket-path", config.SocketPath, "Unix socket path (or named pipe on Windows)") serviceFlags.StringVar(&config.PingInterval, "ping-interval", config.PingInterval, "Interval for pinging the server") serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping") - serviceFlags.BoolVar(&config.EnableHTTP, "enable-http", config.EnableHTTP, "Enable HTTP server for receiving connection requests") + serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests") serviceFlags.BoolVar(&config.Holepunch, "holepunch", config.Holepunch, "Enable hole punching") version := serviceFlags.Bool("version", false, "Print the version") @@ -286,14 +303,17 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { if config.HTTPAddr != origValues["httpAddr"].(string) { config.sources["httpAddr"] = string(SourceCLI) } + if config.SocketPath != origValues["socketPath"].(string) { + config.sources["socketPath"] = string(SourceCLI) + } if config.PingInterval != origValues["pingInterval"].(string) { config.sources["pingInterval"] = string(SourceCLI) } if config.PingTimeout != origValues["pingTimeout"].(string) { config.sources["pingTimeout"] = string(SourceCLI) } - if config.EnableHTTP != origValues["enableHttp"].(bool) { - config.sources["enableHttp"] = string(SourceCLI) + if config.EnableAPI != origValues["enableApi"].(bool) { + config.sources["enableApi"] = string(SourceCLI) } if config.Holepunch != origValues["holepunch"].(bool) { config.sources["holepunch"] = string(SourceCLI) @@ -370,6 +390,14 @@ func mergeConfigs(dest, src *OlmConfig) { dest.HTTPAddr = src.HTTPAddr dest.sources["httpAddr"] = string(SourceFile) } + if src.SocketPath != "" { + // Check if it's not the default for any OS + isDefault := src.SocketPath == "/var/run/olm.sock" || src.SocketPath == "olm" + if !isDefault { + dest.SocketPath = src.SocketPath + dest.sources["socketPath"] = string(SourceFile) + } + } if src.PingInterval != "" && src.PingInterval != "3s" { dest.PingInterval = src.PingInterval dest.sources["pingInterval"] = string(SourceFile) @@ -383,9 +411,9 @@ func mergeConfigs(dest, src *OlmConfig) { dest.sources["tlsClientCert"] = string(SourceFile) } // For booleans, we always take the source value if explicitly set - if src.EnableHTTP { - dest.EnableHTTP = src.EnableHTTP - dest.sources["enableHttp"] = string(SourceFile) + if src.EnableAPI { + dest.EnableAPI = src.EnableAPI + dest.sources["enableApi"] = string(SourceFile) } if src.Holepunch { dest.Holepunch = src.Holepunch @@ -458,10 +486,11 @@ func (c *OlmConfig) ShowConfig() { fmt.Println("\nLogging:") fmt.Printf(" log-level = %s [%s]\n", c.LogLevel, getSource("logLevel")) - // HTTP server - fmt.Println("\nHTTP Server:") - fmt.Printf(" enable-http = %v [%s]\n", c.EnableHTTP, getSource("enableHttp")) + // API server + fmt.Println("\nAPI Server:") + fmt.Printf(" enable-api = %v [%s]\n", c.EnableAPI, getSource("enableApi")) fmt.Printf(" http-addr = %s [%s]\n", c.HTTPAddr, getSource("httpAddr")) + fmt.Printf(" socket-path = %s [%s]\n", c.SocketPath, getSource("socketPath")) // Timing fmt.Println("\nTiming:") diff --git a/main.go b/main.go index d03b680..43bd5fa 100644 --- a/main.go +++ b/main.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "os" + "os/signal" "runtime" + "syscall" "github.com/fosrl/newt/logger" "github.com/fosrl/olm/olm" @@ -197,8 +199,9 @@ func main() { DNS: config.DNS, InterfaceName: config.InterfaceName, LogLevel: config.LogLevel, - EnableHTTP: config.EnableHTTP, + EnableAPI: config.EnableAPI, HTTPAddr: config.HTTPAddr, + SocketPath: config.SocketPath, PingInterval: config.PingInterval, PingTimeout: config.PingTimeout, Holepunch: config.Holepunch, @@ -208,5 +211,9 @@ func main() { Version: config.Version, } - olm.Run(context.Background(), olmConfig) + // Create a context that will be cancelled on interrupt signals + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + olm.Run(ctx, olmConfig) } diff --git a/olm/olm.go b/olm/olm.go index 762bdc8..7c77f69 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -12,7 +12,7 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/newt/updates" - "github.com/fosrl/olm/httpserver" + "github.com/fosrl/olm/api" "github.com/fosrl/olm/peermonitor" "github.com/fosrl/olm/websocket" "golang.zx2c4.com/wireguard/device" @@ -35,8 +35,9 @@ type Config struct { LogLevel string // HTTP server - EnableHTTP bool + EnableAPI bool HTTPAddr string + SocketPath string // Ping settings PingInterval string @@ -65,8 +66,6 @@ func Run(ctx context.Context, config Config) { mtu = config.MTU logLevel = config.LogLevel interfaceName = config.InterfaceName - enableHTTP = config.EnableHTTP - httpAddr = config.HTTPAddr pingInterval = config.PingIntervalDuration pingTimeout = config.PingTimeoutDuration doHolepunch = config.Holepunch @@ -92,33 +91,38 @@ func Run(ctx context.Context, config Config) { // Log startup information logger.Debug("Olm service starting...") logger.Debug("Parameters: endpoint='%s', id='%s', secret='%s'", endpoint, id, secret) - logger.Debug("HTTP enabled: %v, HTTP addr: %s", enableHTTP, httpAddr) if doHolepunch { logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.") } - var httpServer *httpserver.HTTPServer - if enableHTTP { - httpServer = httpserver.NewHTTPServer(httpAddr) - httpServer.SetVersion(config.Version) - if err := httpServer.Start(); err != nil { - logger.Fatal("Failed to start HTTP server: %v", err) + var apiServer *api.API + if config.EnableAPI { + if config.HTTPAddr != "" { + apiServer = api.NewAPI(config.HTTPAddr) + } else if config.SocketPath != "" { + apiServer = api.NewAPISocket(config.SocketPath) } - // Use a goroutine to handle connection requests - go func() { - for req := range httpServer.GetConnectionChannel() { - logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint) - - // Set the connection parameters - id = req.ID - secret = req.Secret - endpoint = req.Endpoint - } - }() + apiServer.SetVersion(config.Version) + if err := apiServer.Start(); err != nil { + logger.Fatal("Failed to start HTTP server: %v", err) + } } + // // Use a goroutine to handle connection requests + // go func() { + // for req := range apiServer.GetConnectionChannel() { + // logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint) + + // // Set the connection parameters + // id = req.ID + // secret = req.Secret + // endpoint = req.Endpoint + // } + // }() + // } + // Create a new olm olm, err := websocket.NewClient( "olm", @@ -329,13 +333,13 @@ func Run(ctx context.Context, config Config) { if err = ConfigureInterface(interfaceName, wgData); err != nil { logger.Error("Failed to configure interface: %v", err) } - if httpServer != nil { - httpServer.SetTunnelIP(wgData.TunnelIP) + if apiServer != nil { + apiServer.SetTunnelIP(wgData.TunnelIP) } peerMonitor = peermonitor.NewPeerMonitor( func(siteID int, connected bool, rtt time.Duration) { - if httpServer != nil { + if apiServer != nil { // Find the site config to get endpoint information var endpoint string var isRelay bool @@ -348,7 +352,7 @@ func Run(ctx context.Context, config Config) { break } } - httpServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay) + apiServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay) } if connected { logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt) @@ -364,8 +368,8 @@ func Run(ctx context.Context, config Config) { for i := range wgData.Sites { site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice - if httpServer != nil { - httpServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false) + if apiServer != nil { + apiServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false) } // Format the endpoint before configuring the peer. @@ -615,8 +619,8 @@ func Run(ctx context.Context, config Config) { } // Update HTTP server to mark this peer as using relay - if httpServer != nil { - httpServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true) + if apiServer != nil { + apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true) } peerMonitor.HandleFailover(relayData.SiteId, primaryRelay) @@ -648,8 +652,8 @@ func Run(ctx context.Context, config Config) { olm.OnConnect(func() error { logger.Info("Websocket Connected") - if httpServer != nil { - httpServer.SetConnectionStatus(true) + if apiServer != nil { + apiServer.SetConnectionStatus(true) } if connected { @@ -707,10 +711,20 @@ func Run(ctx context.Context, config Config) { close(stopPing) } + if peerMonitor != nil { + peerMonitor.Stop() + } + if uapiListener != nil { uapiListener.Close() } if dev != nil { dev.Close() } + + if apiServer != nil { + apiServer.Stop() + } + + logger.Info("Olm service stopped") }