Make api availble over socket

Former-commit-id: e464af5302
This commit is contained in:
Owen
2025-11-02 18:56:09 -08:00
parent ea6fa72bc0
commit a7979259f3
6 changed files with 253 additions and 72 deletions

View File

@@ -1,8 +1,9 @@
package httpserver
package api
import (
"encoding/json"
"fmt"
"net"
"net/http"
"sync"
"time"
@@ -36,9 +37,11 @@ type StatusResponse struct {
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
}
// HTTPServer represents the HTTP server and its state
type HTTPServer struct {
// API represents the HTTP server and its state
type API struct {
addr string
socketPath string
listener net.Listener
server *http.Server
connectionChan chan ConnectionRequest
statusMu sync.RWMutex
@@ -49,9 +52,9 @@ type HTTPServer struct {
version string
}
// NewHTTPServer creates a new HTTP server
func NewHTTPServer(addr string) *HTTPServer {
s := &HTTPServer{
// NewAPI creates a new HTTP server that listens on a TCP address
func NewAPI(addr string) *API {
s := &API{
addr: addr,
connectionChan: make(chan ConnectionRequest, 1),
peerStatuses: make(map[int]*PeerStatus),
@@ -60,20 +63,46 @@ func NewHTTPServer(addr string) *HTTPServer {
return s
}
// NewAPISocket creates a new HTTP server that listens on a Unix socket or Windows named pipe
func NewAPISocket(socketPath string) *API {
s := &API{
socketPath: socketPath,
connectionChan: make(chan ConnectionRequest, 1),
peerStatuses: make(map[int]*PeerStatus),
}
return s
}
// Start starts the HTTP server
func (s *HTTPServer) Start() error {
func (s *API) Start() error {
mux := http.NewServeMux()
mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus)
s.server = &http.Server{
Addr: s.addr,
Handler: mux,
}
var err error
if s.socketPath != "" {
// Use platform-specific socket listener
s.listener, err = createSocketListener(s.socketPath)
if err != nil {
return fmt.Errorf("failed to create socket listener: %w", err)
}
logger.Info("Starting HTTP server on socket %s", s.socketPath)
} else {
// Use TCP listener
s.listener, err = net.Listen("tcp", s.addr)
if err != nil {
return fmt.Errorf("failed to create TCP listener: %w", err)
}
logger.Info("Starting HTTP server on %s", s.addr)
}
go func() {
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
if err := s.server.Serve(s.listener); err != nil && err != http.ErrServerClosed {
logger.Error("HTTP server error: %v", err)
}
}()
@@ -82,18 +111,29 @@ func (s *HTTPServer) Start() error {
}
// Stop stops the HTTP server
func (s *HTTPServer) Stop() error {
logger.Info("Stopping HTTP server")
return s.server.Close()
func (s *API) Stop() error {
logger.Info("Stopping api server")
// Close the server first, which will also close the listener gracefully
if s.server != nil {
s.server.Close()
}
// Clean up socket file if using Unix socket
if s.socketPath != "" {
cleanupSocket(s.socketPath)
}
return nil
}
// GetConnectionChannel returns the channel for receiving connection requests
func (s *HTTPServer) GetConnectionChannel() <-chan ConnectionRequest {
func (s *API) GetConnectionChannel() <-chan ConnectionRequest {
return s.connectionChan
}
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
func (s *HTTPServer) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
@@ -113,7 +153,7 @@ func (s *HTTPServer) UpdatePeerStatus(siteID int, connected bool, rtt time.Durat
}
// SetConnectionStatus sets the overall connection status
func (s *HTTPServer) SetConnectionStatus(isConnected bool) {
func (s *API) SetConnectionStatus(isConnected bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
@@ -128,21 +168,21 @@ func (s *HTTPServer) SetConnectionStatus(isConnected bool) {
}
// SetTunnelIP sets the tunnel IP address
func (s *HTTPServer) SetTunnelIP(tunnelIP string) {
func (s *API) SetTunnelIP(tunnelIP string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.tunnelIP = tunnelIP
}
// SetVersion sets the olm version
func (s *HTTPServer) SetVersion(version string) {
func (s *API) SetVersion(version string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.version = version
}
// UpdatePeerRelayStatus updates only the relay status of a peer
func (s *HTTPServer) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
@@ -159,7 +199,7 @@ func (s *HTTPServer) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay
}
// handleConnect handles the /connect endpoint
func (s *HTTPServer) handleConnect(w http.ResponseWriter, r *http.Request) {
func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
@@ -190,7 +230,7 @@ func (s *HTTPServer) handleConnect(w http.ResponseWriter, r *http.Request) {
}
// handleStatus handles the /status endpoint
func (s *HTTPServer) handleStatus(w http.ResponseWriter, r *http.Request) {
func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return

50
api/api_unix.go Normal file
View File

@@ -0,0 +1,50 @@
//go:build !windows
// +build !windows
package api
import (
"fmt"
"net"
"os"
"path/filepath"
"github.com/fosrl/newt/logger"
)
// createSocketListener creates a Unix domain socket listener
func createSocketListener(socketPath string) (net.Listener, error) {
// Ensure the directory exists
dir := filepath.Dir(socketPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create socket directory: %w", err)
}
// Remove existing socket file if it exists
if err := os.RemoveAll(socketPath); err != nil {
return nil, fmt.Errorf("failed to remove existing socket: %w", err)
}
listener, err := net.Listen("unix", socketPath)
if err != nil {
return nil, fmt.Errorf("failed to listen on Unix socket: %w", err)
}
// Set socket permissions to allow access
if err := os.Chmod(socketPath, 0666); err != nil {
listener.Close()
return nil, fmt.Errorf("failed to set socket permissions: %w", err)
}
logger.Debug("Created Unix socket at %s", socketPath)
return listener, nil
}
// cleanupSocket removes the Unix socket file
func cleanupSocket(socketPath string) {
if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
logger.Error("Failed to remove socket file %s: %v", socketPath, err)
} else {
logger.Debug("Removed Unix socket at %s", socketPath)
}
}

41
api/api_windows.go Normal file
View File

@@ -0,0 +1,41 @@
//go:build windows
// +build windows
package api
import (
"fmt"
"net"
"github.com/Microsoft/go-winio"
"github.com/fosrl/newt/logger"
)
// createSocketListener creates a Windows named pipe listener
func createSocketListener(pipePath string) (net.Listener, error) {
// Ensure the pipe path has the correct format
if pipePath[0] != '\\' {
pipePath = `\\.\pipe\` + pipePath
}
// Create a pipe configuration that allows everyone to write
config := &winio.PipeConfig{
// Set security descriptor to allow everyone full access
// This SDDL string grants full access to Everyone (WD) and to the current owner (OW)
SecurityDescriptor: "D:(A;;GA;;;WD)(A;;GA;;;OW)",
}
// Create a named pipe listener using go-winio with the configuration
listener, err := winio.ListenPipe(pipePath, config)
if err != nil {
return nil, fmt.Errorf("failed to listen on named pipe: %w", err)
}
logger.Debug("Created named pipe at %s with write access for everyone", pipePath)
return listener, nil
}
// cleanupSocket is a no-op on Windows as named pipes are automatically cleaned up
func cleanupSocket(pipePath string) {
logger.Debug("Named pipe %s will be automatically cleaned up", pipePath)
}

View File

@@ -27,8 +27,9 @@ type OlmConfig struct {
LogLevel string `json:"logLevel"`
// HTTP server
EnableHTTP bool `json:"enableHttp"`
EnableAPI bool `json:"enableApi"`
HTTPAddr string `json:"httpAddr"`
SocketPath string `json:"socketPath"`
// Ping settings
PingInterval string `json:"pingInterval"`
@@ -60,13 +61,22 @@ const (
// DefaultConfig returns a config with default values
func DefaultConfig() *OlmConfig {
// Set OS-specific socket path
var socketPath string
switch runtime.GOOS {
case "windows":
socketPath = "olm"
default: // darwin, linux, and others
socketPath = "/var/run/olm.sock"
}
config := &OlmConfig{
MTU: 1280,
DNS: "8.8.8.8",
LogLevel: "INFO",
InterfaceName: "olm",
EnableHTTP: false,
HTTPAddr: ":9452",
EnableAPI: false,
SocketPath: socketPath,
PingInterval: "3s",
PingTimeout: "5s",
Holepunch: false,
@@ -78,8 +88,9 @@ func DefaultConfig() *OlmConfig {
config.sources["dns"] = string(SourceDefault)
config.sources["logLevel"] = string(SourceDefault)
config.sources["interface"] = string(SourceDefault)
config.sources["enableHttp"] = string(SourceDefault)
config.sources["enableApi"] = string(SourceDefault)
config.sources["httpAddr"] = string(SourceDefault)
config.sources["socketPath"] = string(SourceDefault)
config.sources["pingInterval"] = string(SourceDefault)
config.sources["pingTimeout"] = string(SourceDefault)
config.sources["holepunch"] = string(SourceDefault)
@@ -209,9 +220,13 @@ func loadConfigFromEnv(config *OlmConfig) {
config.PingTimeout = val
config.sources["pingTimeout"] = string(SourceEnv)
}
if val := os.Getenv("ENABLE_HTTP"); val == "true" {
config.EnableHTTP = true
config.sources["enableHttp"] = string(SourceEnv)
if val := os.Getenv("ENABLE_API"); val == "true" {
config.EnableAPI = true
config.sources["enableApi"] = string(SourceEnv)
}
if val := os.Getenv("SOCKET_PATH"); val != "" {
config.SocketPath = val
config.sources["socketPath"] = string(SourceEnv)
}
if val := os.Getenv("HOLEPUNCH"); val == "true" {
config.Holepunch = true
@@ -233,9 +248,10 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
"logLevel": config.LogLevel,
"interface": config.InterfaceName,
"httpAddr": config.HTTPAddr,
"socketPath": config.SocketPath,
"pingInterval": config.PingInterval,
"pingTimeout": config.PingTimeout,
"enableHttp": config.EnableHTTP,
"enableApi": config.EnableAPI,
"holepunch": config.Holepunch,
}
@@ -248,9 +264,10 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
serviceFlags.StringVar(&config.InterfaceName, "interface", config.InterfaceName, "Name of the WireGuard interface")
serviceFlags.StringVar(&config.HTTPAddr, "http-addr", config.HTTPAddr, "HTTP server address (e.g., ':9452')")
serviceFlags.StringVar(&config.SocketPath, "socket-path", config.SocketPath, "Unix socket path (or named pipe on Windows)")
serviceFlags.StringVar(&config.PingInterval, "ping-interval", config.PingInterval, "Interval for pinging the server")
serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping")
serviceFlags.BoolVar(&config.EnableHTTP, "enable-http", config.EnableHTTP, "Enable HTTP server for receiving connection requests")
serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests")
serviceFlags.BoolVar(&config.Holepunch, "holepunch", config.Holepunch, "Enable hole punching")
version := serviceFlags.Bool("version", false, "Print the version")
@@ -286,14 +303,17 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
if config.HTTPAddr != origValues["httpAddr"].(string) {
config.sources["httpAddr"] = string(SourceCLI)
}
if config.SocketPath != origValues["socketPath"].(string) {
config.sources["socketPath"] = string(SourceCLI)
}
if config.PingInterval != origValues["pingInterval"].(string) {
config.sources["pingInterval"] = string(SourceCLI)
}
if config.PingTimeout != origValues["pingTimeout"].(string) {
config.sources["pingTimeout"] = string(SourceCLI)
}
if config.EnableHTTP != origValues["enableHttp"].(bool) {
config.sources["enableHttp"] = string(SourceCLI)
if config.EnableAPI != origValues["enableApi"].(bool) {
config.sources["enableApi"] = string(SourceCLI)
}
if config.Holepunch != origValues["holepunch"].(bool) {
config.sources["holepunch"] = string(SourceCLI)
@@ -370,6 +390,14 @@ func mergeConfigs(dest, src *OlmConfig) {
dest.HTTPAddr = src.HTTPAddr
dest.sources["httpAddr"] = string(SourceFile)
}
if src.SocketPath != "" {
// Check if it's not the default for any OS
isDefault := src.SocketPath == "/var/run/olm.sock" || src.SocketPath == "olm"
if !isDefault {
dest.SocketPath = src.SocketPath
dest.sources["socketPath"] = string(SourceFile)
}
}
if src.PingInterval != "" && src.PingInterval != "3s" {
dest.PingInterval = src.PingInterval
dest.sources["pingInterval"] = string(SourceFile)
@@ -383,9 +411,9 @@ func mergeConfigs(dest, src *OlmConfig) {
dest.sources["tlsClientCert"] = string(SourceFile)
}
// For booleans, we always take the source value if explicitly set
if src.EnableHTTP {
dest.EnableHTTP = src.EnableHTTP
dest.sources["enableHttp"] = string(SourceFile)
if src.EnableAPI {
dest.EnableAPI = src.EnableAPI
dest.sources["enableApi"] = string(SourceFile)
}
if src.Holepunch {
dest.Holepunch = src.Holepunch
@@ -458,10 +486,11 @@ func (c *OlmConfig) ShowConfig() {
fmt.Println("\nLogging:")
fmt.Printf(" log-level = %s [%s]\n", c.LogLevel, getSource("logLevel"))
// HTTP server
fmt.Println("\nHTTP Server:")
fmt.Printf(" enable-http = %v [%s]\n", c.EnableHTTP, getSource("enableHttp"))
// API server
fmt.Println("\nAPI Server:")
fmt.Printf(" enable-api = %v [%s]\n", c.EnableAPI, getSource("enableApi"))
fmt.Printf(" http-addr = %s [%s]\n", c.HTTPAddr, getSource("httpAddr"))
fmt.Printf(" socket-path = %s [%s]\n", c.SocketPath, getSource("socketPath"))
// Timing
fmt.Println("\nTiming:")

11
main.go
View File

@@ -4,7 +4,9 @@ import (
"context"
"fmt"
"os"
"os/signal"
"runtime"
"syscall"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/olm"
@@ -197,8 +199,9 @@ func main() {
DNS: config.DNS,
InterfaceName: config.InterfaceName,
LogLevel: config.LogLevel,
EnableHTTP: config.EnableHTTP,
EnableAPI: config.EnableAPI,
HTTPAddr: config.HTTPAddr,
SocketPath: config.SocketPath,
PingInterval: config.PingInterval,
PingTimeout: config.PingTimeout,
Holepunch: config.Holepunch,
@@ -208,5 +211,9 @@ func main() {
Version: config.Version,
}
olm.Run(context.Background(), olmConfig)
// Create a context that will be cancelled on interrupt signals
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
olm.Run(ctx, olmConfig)
}

View File

@@ -12,7 +12,7 @@ import (
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/updates"
"github.com/fosrl/olm/httpserver"
"github.com/fosrl/olm/api"
"github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/websocket"
"golang.zx2c4.com/wireguard/device"
@@ -35,8 +35,9 @@ type Config struct {
LogLevel string
// HTTP server
EnableHTTP bool
EnableAPI bool
HTTPAddr string
SocketPath string
// Ping settings
PingInterval string
@@ -65,8 +66,6 @@ func Run(ctx context.Context, config Config) {
mtu = config.MTU
logLevel = config.LogLevel
interfaceName = config.InterfaceName
enableHTTP = config.EnableHTTP
httpAddr = config.HTTPAddr
pingInterval = config.PingIntervalDuration
pingTimeout = config.PingTimeoutDuration
doHolepunch = config.Holepunch
@@ -92,33 +91,38 @@ func Run(ctx context.Context, config Config) {
// Log startup information
logger.Debug("Olm service starting...")
logger.Debug("Parameters: endpoint='%s', id='%s', secret='%s'", endpoint, id, secret)
logger.Debug("HTTP enabled: %v, HTTP addr: %s", enableHTTP, httpAddr)
if doHolepunch {
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
}
var httpServer *httpserver.HTTPServer
if enableHTTP {
httpServer = httpserver.NewHTTPServer(httpAddr)
httpServer.SetVersion(config.Version)
if err := httpServer.Start(); err != nil {
var apiServer *api.API
if config.EnableAPI {
if config.HTTPAddr != "" {
apiServer = api.NewAPI(config.HTTPAddr)
} else if config.SocketPath != "" {
apiServer = api.NewAPISocket(config.SocketPath)
}
apiServer.SetVersion(config.Version)
if err := apiServer.Start(); err != nil {
logger.Fatal("Failed to start HTTP server: %v", err)
}
// Use a goroutine to handle connection requests
go func() {
for req := range httpServer.GetConnectionChannel() {
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
// Set the connection parameters
id = req.ID
secret = req.Secret
endpoint = req.Endpoint
}
}()
}
// // Use a goroutine to handle connection requests
// go func() {
// for req := range apiServer.GetConnectionChannel() {
// logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
// // Set the connection parameters
// id = req.ID
// secret = req.Secret
// endpoint = req.Endpoint
// }
// }()
// }
// Create a new olm
olm, err := websocket.NewClient(
"olm",
@@ -329,13 +333,13 @@ func Run(ctx context.Context, config Config) {
if err = ConfigureInterface(interfaceName, wgData); err != nil {
logger.Error("Failed to configure interface: %v", err)
}
if httpServer != nil {
httpServer.SetTunnelIP(wgData.TunnelIP)
if apiServer != nil {
apiServer.SetTunnelIP(wgData.TunnelIP)
}
peerMonitor = peermonitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) {
if httpServer != nil {
if apiServer != nil {
// Find the site config to get endpoint information
var endpoint string
var isRelay bool
@@ -348,7 +352,7 @@ func Run(ctx context.Context, config Config) {
break
}
}
httpServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
apiServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
}
if connected {
logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
@@ -364,8 +368,8 @@ func Run(ctx context.Context, config Config) {
for i := range wgData.Sites {
site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice
if httpServer != nil {
httpServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
if apiServer != nil {
apiServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
}
// Format the endpoint before configuring the peer.
@@ -615,8 +619,8 @@ func Run(ctx context.Context, config Config) {
}
// Update HTTP server to mark this peer as using relay
if httpServer != nil {
httpServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
if apiServer != nil {
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
}
peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
@@ -648,8 +652,8 @@ func Run(ctx context.Context, config Config) {
olm.OnConnect(func() error {
logger.Info("Websocket Connected")
if httpServer != nil {
httpServer.SetConnectionStatus(true)
if apiServer != nil {
apiServer.SetConnectionStatus(true)
}
if connected {
@@ -707,10 +711,20 @@ func Run(ctx context.Context, config Config) {
close(stopPing)
}
if peerMonitor != nil {
peerMonitor.Stop()
}
if uapiListener != nil {
uapiListener.Close()
}
if dev != nil {
dev.Close()
}
if apiServer != nil {
apiServer.Stop()
}
logger.Info("Olm service stopped")
}