Make api availble over socket

Former-commit-id: e464af5302
This commit is contained in:
Owen
2025-11-02 18:56:09 -08:00
parent ea6fa72bc0
commit a7979259f3
6 changed files with 253 additions and 72 deletions

View File

@@ -1,8 +1,9 @@
package httpserver package api
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net"
"net/http" "net/http"
"sync" "sync"
"time" "time"
@@ -36,9 +37,11 @@ type StatusResponse struct {
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"` PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
} }
// HTTPServer represents the HTTP server and its state // API represents the HTTP server and its state
type HTTPServer struct { type API struct {
addr string addr string
socketPath string
listener net.Listener
server *http.Server server *http.Server
connectionChan chan ConnectionRequest connectionChan chan ConnectionRequest
statusMu sync.RWMutex statusMu sync.RWMutex
@@ -49,9 +52,9 @@ type HTTPServer struct {
version string version string
} }
// NewHTTPServer creates a new HTTP server // NewAPI creates a new HTTP server that listens on a TCP address
func NewHTTPServer(addr string) *HTTPServer { func NewAPI(addr string) *API {
s := &HTTPServer{ s := &API{
addr: addr, addr: addr,
connectionChan: make(chan ConnectionRequest, 1), connectionChan: make(chan ConnectionRequest, 1),
peerStatuses: make(map[int]*PeerStatus), peerStatuses: make(map[int]*PeerStatus),
@@ -60,20 +63,46 @@ func NewHTTPServer(addr string) *HTTPServer {
return s return s
} }
// NewAPISocket creates a new HTTP server that listens on a Unix socket or Windows named pipe
func NewAPISocket(socketPath string) *API {
s := &API{
socketPath: socketPath,
connectionChan: make(chan ConnectionRequest, 1),
peerStatuses: make(map[int]*PeerStatus),
}
return s
}
// Start starts the HTTP server // Start starts the HTTP server
func (s *HTTPServer) Start() error { func (s *API) Start() error {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/connect", s.handleConnect) mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus) mux.HandleFunc("/status", s.handleStatus)
s.server = &http.Server{ s.server = &http.Server{
Addr: s.addr,
Handler: mux, Handler: mux,
} }
logger.Info("Starting HTTP server on %s", s.addr) var err error
if s.socketPath != "" {
// Use platform-specific socket listener
s.listener, err = createSocketListener(s.socketPath)
if err != nil {
return fmt.Errorf("failed to create socket listener: %w", err)
}
logger.Info("Starting HTTP server on socket %s", s.socketPath)
} else {
// Use TCP listener
s.listener, err = net.Listen("tcp", s.addr)
if err != nil {
return fmt.Errorf("failed to create TCP listener: %w", err)
}
logger.Info("Starting HTTP server on %s", s.addr)
}
go func() { go func() {
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := s.server.Serve(s.listener); err != nil && err != http.ErrServerClosed {
logger.Error("HTTP server error: %v", err) logger.Error("HTTP server error: %v", err)
} }
}() }()
@@ -82,18 +111,29 @@ func (s *HTTPServer) Start() error {
} }
// Stop stops the HTTP server // Stop stops the HTTP server
func (s *HTTPServer) Stop() error { func (s *API) Stop() error {
logger.Info("Stopping HTTP server") logger.Info("Stopping api server")
return s.server.Close()
// Close the server first, which will also close the listener gracefully
if s.server != nil {
s.server.Close()
}
// Clean up socket file if using Unix socket
if s.socketPath != "" {
cleanupSocket(s.socketPath)
}
return nil
} }
// GetConnectionChannel returns the channel for receiving connection requests // GetConnectionChannel returns the channel for receiving connection requests
func (s *HTTPServer) GetConnectionChannel() <-chan ConnectionRequest { func (s *API) GetConnectionChannel() <-chan ConnectionRequest {
return s.connectionChan return s.connectionChan
} }
// UpdatePeerStatus updates the status of a peer including endpoint and relay info // UpdatePeerStatus updates the status of a peer including endpoint and relay info
func (s *HTTPServer) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) { func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
s.statusMu.Lock() s.statusMu.Lock()
defer s.statusMu.Unlock() defer s.statusMu.Unlock()
@@ -113,7 +153,7 @@ func (s *HTTPServer) UpdatePeerStatus(siteID int, connected bool, rtt time.Durat
} }
// SetConnectionStatus sets the overall connection status // SetConnectionStatus sets the overall connection status
func (s *HTTPServer) SetConnectionStatus(isConnected bool) { func (s *API) SetConnectionStatus(isConnected bool) {
s.statusMu.Lock() s.statusMu.Lock()
defer s.statusMu.Unlock() defer s.statusMu.Unlock()
@@ -128,21 +168,21 @@ func (s *HTTPServer) SetConnectionStatus(isConnected bool) {
} }
// SetTunnelIP sets the tunnel IP address // SetTunnelIP sets the tunnel IP address
func (s *HTTPServer) SetTunnelIP(tunnelIP string) { func (s *API) SetTunnelIP(tunnelIP string) {
s.statusMu.Lock() s.statusMu.Lock()
defer s.statusMu.Unlock() defer s.statusMu.Unlock()
s.tunnelIP = tunnelIP s.tunnelIP = tunnelIP
} }
// SetVersion sets the olm version // SetVersion sets the olm version
func (s *HTTPServer) SetVersion(version string) { func (s *API) SetVersion(version string) {
s.statusMu.Lock() s.statusMu.Lock()
defer s.statusMu.Unlock() defer s.statusMu.Unlock()
s.version = version s.version = version
} }
// UpdatePeerRelayStatus updates only the relay status of a peer // UpdatePeerRelayStatus updates only the relay status of a peer
func (s *HTTPServer) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) { func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
s.statusMu.Lock() s.statusMu.Lock()
defer s.statusMu.Unlock() defer s.statusMu.Unlock()
@@ -159,7 +199,7 @@ func (s *HTTPServer) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay
} }
// handleConnect handles the /connect endpoint // handleConnect handles the /connect endpoint
func (s *HTTPServer) handleConnect(w http.ResponseWriter, r *http.Request) { func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return return
@@ -190,7 +230,7 @@ func (s *HTTPServer) handleConnect(w http.ResponseWriter, r *http.Request) {
} }
// handleStatus handles the /status endpoint // handleStatus handles the /status endpoint
func (s *HTTPServer) handleStatus(w http.ResponseWriter, r *http.Request) { func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return return

50
api/api_unix.go Normal file
View File

@@ -0,0 +1,50 @@
//go:build !windows
// +build !windows
package api
import (
"fmt"
"net"
"os"
"path/filepath"
"github.com/fosrl/newt/logger"
)
// createSocketListener creates a Unix domain socket listener
func createSocketListener(socketPath string) (net.Listener, error) {
// Ensure the directory exists
dir := filepath.Dir(socketPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create socket directory: %w", err)
}
// Remove existing socket file if it exists
if err := os.RemoveAll(socketPath); err != nil {
return nil, fmt.Errorf("failed to remove existing socket: %w", err)
}
listener, err := net.Listen("unix", socketPath)
if err != nil {
return nil, fmt.Errorf("failed to listen on Unix socket: %w", err)
}
// Set socket permissions to allow access
if err := os.Chmod(socketPath, 0666); err != nil {
listener.Close()
return nil, fmt.Errorf("failed to set socket permissions: %w", err)
}
logger.Debug("Created Unix socket at %s", socketPath)
return listener, nil
}
// cleanupSocket removes the Unix socket file
func cleanupSocket(socketPath string) {
if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
logger.Error("Failed to remove socket file %s: %v", socketPath, err)
} else {
logger.Debug("Removed Unix socket at %s", socketPath)
}
}

41
api/api_windows.go Normal file
View File

@@ -0,0 +1,41 @@
//go:build windows
// +build windows
package api
import (
"fmt"
"net"
"github.com/Microsoft/go-winio"
"github.com/fosrl/newt/logger"
)
// createSocketListener creates a Windows named pipe listener
func createSocketListener(pipePath string) (net.Listener, error) {
// Ensure the pipe path has the correct format
if pipePath[0] != '\\' {
pipePath = `\\.\pipe\` + pipePath
}
// Create a pipe configuration that allows everyone to write
config := &winio.PipeConfig{
// Set security descriptor to allow everyone full access
// This SDDL string grants full access to Everyone (WD) and to the current owner (OW)
SecurityDescriptor: "D:(A;;GA;;;WD)(A;;GA;;;OW)",
}
// Create a named pipe listener using go-winio with the configuration
listener, err := winio.ListenPipe(pipePath, config)
if err != nil {
return nil, fmt.Errorf("failed to listen on named pipe: %w", err)
}
logger.Debug("Created named pipe at %s with write access for everyone", pipePath)
return listener, nil
}
// cleanupSocket is a no-op on Windows as named pipes are automatically cleaned up
func cleanupSocket(pipePath string) {
logger.Debug("Named pipe %s will be automatically cleaned up", pipePath)
}

View File

@@ -27,8 +27,9 @@ type OlmConfig struct {
LogLevel string `json:"logLevel"` LogLevel string `json:"logLevel"`
// HTTP server // HTTP server
EnableHTTP bool `json:"enableHttp"` EnableAPI bool `json:"enableApi"`
HTTPAddr string `json:"httpAddr"` HTTPAddr string `json:"httpAddr"`
SocketPath string `json:"socketPath"`
// Ping settings // Ping settings
PingInterval string `json:"pingInterval"` PingInterval string `json:"pingInterval"`
@@ -60,13 +61,22 @@ const (
// DefaultConfig returns a config with default values // DefaultConfig returns a config with default values
func DefaultConfig() *OlmConfig { func DefaultConfig() *OlmConfig {
// Set OS-specific socket path
var socketPath string
switch runtime.GOOS {
case "windows":
socketPath = "olm"
default: // darwin, linux, and others
socketPath = "/var/run/olm.sock"
}
config := &OlmConfig{ config := &OlmConfig{
MTU: 1280, MTU: 1280,
DNS: "8.8.8.8", DNS: "8.8.8.8",
LogLevel: "INFO", LogLevel: "INFO",
InterfaceName: "olm", InterfaceName: "olm",
EnableHTTP: false, EnableAPI: false,
HTTPAddr: ":9452", SocketPath: socketPath,
PingInterval: "3s", PingInterval: "3s",
PingTimeout: "5s", PingTimeout: "5s",
Holepunch: false, Holepunch: false,
@@ -78,8 +88,9 @@ func DefaultConfig() *OlmConfig {
config.sources["dns"] = string(SourceDefault) config.sources["dns"] = string(SourceDefault)
config.sources["logLevel"] = string(SourceDefault) config.sources["logLevel"] = string(SourceDefault)
config.sources["interface"] = string(SourceDefault) config.sources["interface"] = string(SourceDefault)
config.sources["enableHttp"] = string(SourceDefault) config.sources["enableApi"] = string(SourceDefault)
config.sources["httpAddr"] = string(SourceDefault) config.sources["httpAddr"] = string(SourceDefault)
config.sources["socketPath"] = string(SourceDefault)
config.sources["pingInterval"] = string(SourceDefault) config.sources["pingInterval"] = string(SourceDefault)
config.sources["pingTimeout"] = string(SourceDefault) config.sources["pingTimeout"] = string(SourceDefault)
config.sources["holepunch"] = string(SourceDefault) config.sources["holepunch"] = string(SourceDefault)
@@ -209,9 +220,13 @@ func loadConfigFromEnv(config *OlmConfig) {
config.PingTimeout = val config.PingTimeout = val
config.sources["pingTimeout"] = string(SourceEnv) config.sources["pingTimeout"] = string(SourceEnv)
} }
if val := os.Getenv("ENABLE_HTTP"); val == "true" { if val := os.Getenv("ENABLE_API"); val == "true" {
config.EnableHTTP = true config.EnableAPI = true
config.sources["enableHttp"] = string(SourceEnv) config.sources["enableApi"] = string(SourceEnv)
}
if val := os.Getenv("SOCKET_PATH"); val != "" {
config.SocketPath = val
config.sources["socketPath"] = string(SourceEnv)
} }
if val := os.Getenv("HOLEPUNCH"); val == "true" { if val := os.Getenv("HOLEPUNCH"); val == "true" {
config.Holepunch = true config.Holepunch = true
@@ -233,9 +248,10 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
"logLevel": config.LogLevel, "logLevel": config.LogLevel,
"interface": config.InterfaceName, "interface": config.InterfaceName,
"httpAddr": config.HTTPAddr, "httpAddr": config.HTTPAddr,
"socketPath": config.SocketPath,
"pingInterval": config.PingInterval, "pingInterval": config.PingInterval,
"pingTimeout": config.PingTimeout, "pingTimeout": config.PingTimeout,
"enableHttp": config.EnableHTTP, "enableApi": config.EnableAPI,
"holepunch": config.Holepunch, "holepunch": config.Holepunch,
} }
@@ -248,9 +264,10 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
serviceFlags.StringVar(&config.InterfaceName, "interface", config.InterfaceName, "Name of the WireGuard interface") serviceFlags.StringVar(&config.InterfaceName, "interface", config.InterfaceName, "Name of the WireGuard interface")
serviceFlags.StringVar(&config.HTTPAddr, "http-addr", config.HTTPAddr, "HTTP server address (e.g., ':9452')") serviceFlags.StringVar(&config.HTTPAddr, "http-addr", config.HTTPAddr, "HTTP server address (e.g., ':9452')")
serviceFlags.StringVar(&config.SocketPath, "socket-path", config.SocketPath, "Unix socket path (or named pipe on Windows)")
serviceFlags.StringVar(&config.PingInterval, "ping-interval", config.PingInterval, "Interval for pinging the server") serviceFlags.StringVar(&config.PingInterval, "ping-interval", config.PingInterval, "Interval for pinging the server")
serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping") serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping")
serviceFlags.BoolVar(&config.EnableHTTP, "enable-http", config.EnableHTTP, "Enable HTTP server for receiving connection requests") serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests")
serviceFlags.BoolVar(&config.Holepunch, "holepunch", config.Holepunch, "Enable hole punching") serviceFlags.BoolVar(&config.Holepunch, "holepunch", config.Holepunch, "Enable hole punching")
version := serviceFlags.Bool("version", false, "Print the version") version := serviceFlags.Bool("version", false, "Print the version")
@@ -286,14 +303,17 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
if config.HTTPAddr != origValues["httpAddr"].(string) { if config.HTTPAddr != origValues["httpAddr"].(string) {
config.sources["httpAddr"] = string(SourceCLI) config.sources["httpAddr"] = string(SourceCLI)
} }
if config.SocketPath != origValues["socketPath"].(string) {
config.sources["socketPath"] = string(SourceCLI)
}
if config.PingInterval != origValues["pingInterval"].(string) { if config.PingInterval != origValues["pingInterval"].(string) {
config.sources["pingInterval"] = string(SourceCLI) config.sources["pingInterval"] = string(SourceCLI)
} }
if config.PingTimeout != origValues["pingTimeout"].(string) { if config.PingTimeout != origValues["pingTimeout"].(string) {
config.sources["pingTimeout"] = string(SourceCLI) config.sources["pingTimeout"] = string(SourceCLI)
} }
if config.EnableHTTP != origValues["enableHttp"].(bool) { if config.EnableAPI != origValues["enableApi"].(bool) {
config.sources["enableHttp"] = string(SourceCLI) config.sources["enableApi"] = string(SourceCLI)
} }
if config.Holepunch != origValues["holepunch"].(bool) { if config.Holepunch != origValues["holepunch"].(bool) {
config.sources["holepunch"] = string(SourceCLI) config.sources["holepunch"] = string(SourceCLI)
@@ -370,6 +390,14 @@ func mergeConfigs(dest, src *OlmConfig) {
dest.HTTPAddr = src.HTTPAddr dest.HTTPAddr = src.HTTPAddr
dest.sources["httpAddr"] = string(SourceFile) dest.sources["httpAddr"] = string(SourceFile)
} }
if src.SocketPath != "" {
// Check if it's not the default for any OS
isDefault := src.SocketPath == "/var/run/olm.sock" || src.SocketPath == "olm"
if !isDefault {
dest.SocketPath = src.SocketPath
dest.sources["socketPath"] = string(SourceFile)
}
}
if src.PingInterval != "" && src.PingInterval != "3s" { if src.PingInterval != "" && src.PingInterval != "3s" {
dest.PingInterval = src.PingInterval dest.PingInterval = src.PingInterval
dest.sources["pingInterval"] = string(SourceFile) dest.sources["pingInterval"] = string(SourceFile)
@@ -383,9 +411,9 @@ func mergeConfigs(dest, src *OlmConfig) {
dest.sources["tlsClientCert"] = string(SourceFile) dest.sources["tlsClientCert"] = string(SourceFile)
} }
// For booleans, we always take the source value if explicitly set // For booleans, we always take the source value if explicitly set
if src.EnableHTTP { if src.EnableAPI {
dest.EnableHTTP = src.EnableHTTP dest.EnableAPI = src.EnableAPI
dest.sources["enableHttp"] = string(SourceFile) dest.sources["enableApi"] = string(SourceFile)
} }
if src.Holepunch { if src.Holepunch {
dest.Holepunch = src.Holepunch dest.Holepunch = src.Holepunch
@@ -458,10 +486,11 @@ func (c *OlmConfig) ShowConfig() {
fmt.Println("\nLogging:") fmt.Println("\nLogging:")
fmt.Printf(" log-level = %s [%s]\n", c.LogLevel, getSource("logLevel")) fmt.Printf(" log-level = %s [%s]\n", c.LogLevel, getSource("logLevel"))
// HTTP server // API server
fmt.Println("\nHTTP Server:") fmt.Println("\nAPI Server:")
fmt.Printf(" enable-http = %v [%s]\n", c.EnableHTTP, getSource("enableHttp")) fmt.Printf(" enable-api = %v [%s]\n", c.EnableAPI, getSource("enableApi"))
fmt.Printf(" http-addr = %s [%s]\n", c.HTTPAddr, getSource("httpAddr")) fmt.Printf(" http-addr = %s [%s]\n", c.HTTPAddr, getSource("httpAddr"))
fmt.Printf(" socket-path = %s [%s]\n", c.SocketPath, getSource("socketPath"))
// Timing // Timing
fmt.Println("\nTiming:") fmt.Println("\nTiming:")

11
main.go
View File

@@ -4,7 +4,9 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"os/signal"
"runtime" "runtime"
"syscall"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/fosrl/olm/olm" "github.com/fosrl/olm/olm"
@@ -197,8 +199,9 @@ func main() {
DNS: config.DNS, DNS: config.DNS,
InterfaceName: config.InterfaceName, InterfaceName: config.InterfaceName,
LogLevel: config.LogLevel, LogLevel: config.LogLevel,
EnableHTTP: config.EnableHTTP, EnableAPI: config.EnableAPI,
HTTPAddr: config.HTTPAddr, HTTPAddr: config.HTTPAddr,
SocketPath: config.SocketPath,
PingInterval: config.PingInterval, PingInterval: config.PingInterval,
PingTimeout: config.PingTimeout, PingTimeout: config.PingTimeout,
Holepunch: config.Holepunch, Holepunch: config.Holepunch,
@@ -208,5 +211,9 @@ func main() {
Version: config.Version, Version: config.Version,
} }
olm.Run(context.Background(), olmConfig) // Create a context that will be cancelled on interrupt signals
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
olm.Run(ctx, olmConfig)
} }

View File

@@ -12,7 +12,7 @@ import (
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/fosrl/newt/updates" "github.com/fosrl/newt/updates"
"github.com/fosrl/olm/httpserver" "github.com/fosrl/olm/api"
"github.com/fosrl/olm/peermonitor" "github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/websocket" "github.com/fosrl/olm/websocket"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
@@ -35,8 +35,9 @@ type Config struct {
LogLevel string LogLevel string
// HTTP server // HTTP server
EnableHTTP bool EnableAPI bool
HTTPAddr string HTTPAddr string
SocketPath string
// Ping settings // Ping settings
PingInterval string PingInterval string
@@ -65,8 +66,6 @@ func Run(ctx context.Context, config Config) {
mtu = config.MTU mtu = config.MTU
logLevel = config.LogLevel logLevel = config.LogLevel
interfaceName = config.InterfaceName interfaceName = config.InterfaceName
enableHTTP = config.EnableHTTP
httpAddr = config.HTTPAddr
pingInterval = config.PingIntervalDuration pingInterval = config.PingIntervalDuration
pingTimeout = config.PingTimeoutDuration pingTimeout = config.PingTimeoutDuration
doHolepunch = config.Holepunch doHolepunch = config.Holepunch
@@ -92,33 +91,38 @@ func Run(ctx context.Context, config Config) {
// Log startup information // Log startup information
logger.Debug("Olm service starting...") logger.Debug("Olm service starting...")
logger.Debug("Parameters: endpoint='%s', id='%s', secret='%s'", endpoint, id, secret) logger.Debug("Parameters: endpoint='%s', id='%s', secret='%s'", endpoint, id, secret)
logger.Debug("HTTP enabled: %v, HTTP addr: %s", enableHTTP, httpAddr)
if doHolepunch { if doHolepunch {
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.") logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
} }
var httpServer *httpserver.HTTPServer var apiServer *api.API
if enableHTTP { if config.EnableAPI {
httpServer = httpserver.NewHTTPServer(httpAddr) if config.HTTPAddr != "" {
httpServer.SetVersion(config.Version) apiServer = api.NewAPI(config.HTTPAddr)
if err := httpServer.Start(); err != nil { } else if config.SocketPath != "" {
logger.Fatal("Failed to start HTTP server: %v", err) apiServer = api.NewAPISocket(config.SocketPath)
} }
// Use a goroutine to handle connection requests apiServer.SetVersion(config.Version)
go func() { if err := apiServer.Start(); err != nil {
for req := range httpServer.GetConnectionChannel() { logger.Fatal("Failed to start HTTP server: %v", err)
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint) }
// Set the connection parameters
id = req.ID
secret = req.Secret
endpoint = req.Endpoint
}
}()
} }
// // Use a goroutine to handle connection requests
// go func() {
// for req := range apiServer.GetConnectionChannel() {
// logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
// // Set the connection parameters
// id = req.ID
// secret = req.Secret
// endpoint = req.Endpoint
// }
// }()
// }
// Create a new olm // Create a new olm
olm, err := websocket.NewClient( olm, err := websocket.NewClient(
"olm", "olm",
@@ -329,13 +333,13 @@ func Run(ctx context.Context, config Config) {
if err = ConfigureInterface(interfaceName, wgData); err != nil { if err = ConfigureInterface(interfaceName, wgData); err != nil {
logger.Error("Failed to configure interface: %v", err) logger.Error("Failed to configure interface: %v", err)
} }
if httpServer != nil { if apiServer != nil {
httpServer.SetTunnelIP(wgData.TunnelIP) apiServer.SetTunnelIP(wgData.TunnelIP)
} }
peerMonitor = peermonitor.NewPeerMonitor( peerMonitor = peermonitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) { func(siteID int, connected bool, rtt time.Duration) {
if httpServer != nil { if apiServer != nil {
// Find the site config to get endpoint information // Find the site config to get endpoint information
var endpoint string var endpoint string
var isRelay bool var isRelay bool
@@ -348,7 +352,7 @@ func Run(ctx context.Context, config Config) {
break break
} }
} }
httpServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay) apiServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
} }
if connected { if connected {
logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt) logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
@@ -364,8 +368,8 @@ func Run(ctx context.Context, config Config) {
for i := range wgData.Sites { for i := range wgData.Sites {
site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice
if httpServer != nil { if apiServer != nil {
httpServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false) apiServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
} }
// Format the endpoint before configuring the peer. // Format the endpoint before configuring the peer.
@@ -615,8 +619,8 @@ func Run(ctx context.Context, config Config) {
} }
// Update HTTP server to mark this peer as using relay // Update HTTP server to mark this peer as using relay
if httpServer != nil { if apiServer != nil {
httpServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true) apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
} }
peerMonitor.HandleFailover(relayData.SiteId, primaryRelay) peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
@@ -648,8 +652,8 @@ func Run(ctx context.Context, config Config) {
olm.OnConnect(func() error { olm.OnConnect(func() error {
logger.Info("Websocket Connected") logger.Info("Websocket Connected")
if httpServer != nil { if apiServer != nil {
httpServer.SetConnectionStatus(true) apiServer.SetConnectionStatus(true)
} }
if connected { if connected {
@@ -707,10 +711,20 @@ func Run(ctx context.Context, config Config) {
close(stopPing) close(stopPing)
} }
if peerMonitor != nil {
peerMonitor.Stop()
}
if uapiListener != nil { if uapiListener != nil {
uapiListener.Close() uapiListener.Close()
} }
if dev != nil { if dev != nil {
dev.Close() dev.Close()
} }
if apiServer != nil {
apiServer.Stop()
}
logger.Info("Olm service stopped")
} }