mirror of
https://github.com/fosrl/olm.git
synced 2026-03-05 02:06:48 +00:00
@@ -1,8 +1,9 @@
|
|||||||
package httpserver
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -36,9 +37,11 @@ type StatusResponse struct {
|
|||||||
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
|
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTPServer represents the HTTP server and its state
|
// API represents the HTTP server and its state
|
||||||
type HTTPServer struct {
|
type API struct {
|
||||||
addr string
|
addr string
|
||||||
|
socketPath string
|
||||||
|
listener net.Listener
|
||||||
server *http.Server
|
server *http.Server
|
||||||
connectionChan chan ConnectionRequest
|
connectionChan chan ConnectionRequest
|
||||||
statusMu sync.RWMutex
|
statusMu sync.RWMutex
|
||||||
@@ -49,9 +52,9 @@ type HTTPServer struct {
|
|||||||
version string
|
version string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHTTPServer creates a new HTTP server
|
// NewAPI creates a new HTTP server that listens on a TCP address
|
||||||
func NewHTTPServer(addr string) *HTTPServer {
|
func NewAPI(addr string) *API {
|
||||||
s := &HTTPServer{
|
s := &API{
|
||||||
addr: addr,
|
addr: addr,
|
||||||
connectionChan: make(chan ConnectionRequest, 1),
|
connectionChan: make(chan ConnectionRequest, 1),
|
||||||
peerStatuses: make(map[int]*PeerStatus),
|
peerStatuses: make(map[int]*PeerStatus),
|
||||||
@@ -60,20 +63,46 @@ func NewHTTPServer(addr string) *HTTPServer {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewAPISocket creates a new HTTP server that listens on a Unix socket or Windows named pipe
|
||||||
|
func NewAPISocket(socketPath string) *API {
|
||||||
|
s := &API{
|
||||||
|
socketPath: socketPath,
|
||||||
|
connectionChan: make(chan ConnectionRequest, 1),
|
||||||
|
peerStatuses: make(map[int]*PeerStatus),
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
// Start starts the HTTP server
|
// Start starts the HTTP server
|
||||||
func (s *HTTPServer) Start() error {
|
func (s *API) Start() error {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc("/connect", s.handleConnect)
|
mux.HandleFunc("/connect", s.handleConnect)
|
||||||
mux.HandleFunc("/status", s.handleStatus)
|
mux.HandleFunc("/status", s.handleStatus)
|
||||||
|
|
||||||
s.server = &http.Server{
|
s.server = &http.Server{
|
||||||
Addr: s.addr,
|
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Starting HTTP server on %s", s.addr)
|
var err error
|
||||||
|
if s.socketPath != "" {
|
||||||
|
// Use platform-specific socket listener
|
||||||
|
s.listener, err = createSocketListener(s.socketPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create socket listener: %w", err)
|
||||||
|
}
|
||||||
|
logger.Info("Starting HTTP server on socket %s", s.socketPath)
|
||||||
|
} else {
|
||||||
|
// Use TCP listener
|
||||||
|
s.listener, err = net.Listen("tcp", s.addr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create TCP listener: %w", err)
|
||||||
|
}
|
||||||
|
logger.Info("Starting HTTP server on %s", s.addr)
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
if err := s.server.Serve(s.listener); err != nil && err != http.ErrServerClosed {
|
||||||
logger.Error("HTTP server error: %v", err)
|
logger.Error("HTTP server error: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -82,18 +111,29 @@ func (s *HTTPServer) Start() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Stop stops the HTTP server
|
// Stop stops the HTTP server
|
||||||
func (s *HTTPServer) Stop() error {
|
func (s *API) Stop() error {
|
||||||
logger.Info("Stopping HTTP server")
|
logger.Info("Stopping api server")
|
||||||
return s.server.Close()
|
|
||||||
|
// Close the server first, which will also close the listener gracefully
|
||||||
|
if s.server != nil {
|
||||||
|
s.server.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up socket file if using Unix socket
|
||||||
|
if s.socketPath != "" {
|
||||||
|
cleanupSocket(s.socketPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConnectionChannel returns the channel for receiving connection requests
|
// GetConnectionChannel returns the channel for receiving connection requests
|
||||||
func (s *HTTPServer) GetConnectionChannel() <-chan ConnectionRequest {
|
func (s *API) GetConnectionChannel() <-chan ConnectionRequest {
|
||||||
return s.connectionChan
|
return s.connectionChan
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
|
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
|
||||||
func (s *HTTPServer) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
|
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
|
||||||
s.statusMu.Lock()
|
s.statusMu.Lock()
|
||||||
defer s.statusMu.Unlock()
|
defer s.statusMu.Unlock()
|
||||||
|
|
||||||
@@ -113,7 +153,7 @@ func (s *HTTPServer) UpdatePeerStatus(siteID int, connected bool, rtt time.Durat
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetConnectionStatus sets the overall connection status
|
// SetConnectionStatus sets the overall connection status
|
||||||
func (s *HTTPServer) SetConnectionStatus(isConnected bool) {
|
func (s *API) SetConnectionStatus(isConnected bool) {
|
||||||
s.statusMu.Lock()
|
s.statusMu.Lock()
|
||||||
defer s.statusMu.Unlock()
|
defer s.statusMu.Unlock()
|
||||||
|
|
||||||
@@ -128,21 +168,21 @@ func (s *HTTPServer) SetConnectionStatus(isConnected bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetTunnelIP sets the tunnel IP address
|
// SetTunnelIP sets the tunnel IP address
|
||||||
func (s *HTTPServer) SetTunnelIP(tunnelIP string) {
|
func (s *API) SetTunnelIP(tunnelIP string) {
|
||||||
s.statusMu.Lock()
|
s.statusMu.Lock()
|
||||||
defer s.statusMu.Unlock()
|
defer s.statusMu.Unlock()
|
||||||
s.tunnelIP = tunnelIP
|
s.tunnelIP = tunnelIP
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetVersion sets the olm version
|
// SetVersion sets the olm version
|
||||||
func (s *HTTPServer) SetVersion(version string) {
|
func (s *API) SetVersion(version string) {
|
||||||
s.statusMu.Lock()
|
s.statusMu.Lock()
|
||||||
defer s.statusMu.Unlock()
|
defer s.statusMu.Unlock()
|
||||||
s.version = version
|
s.version = version
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePeerRelayStatus updates only the relay status of a peer
|
// UpdatePeerRelayStatus updates only the relay status of a peer
|
||||||
func (s *HTTPServer) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
|
func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
|
||||||
s.statusMu.Lock()
|
s.statusMu.Lock()
|
||||||
defer s.statusMu.Unlock()
|
defer s.statusMu.Unlock()
|
||||||
|
|
||||||
@@ -159,7 +199,7 @@ func (s *HTTPServer) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleConnect handles the /connect endpoint
|
// handleConnect handles the /connect endpoint
|
||||||
func (s *HTTPServer) handleConnect(w http.ResponseWriter, r *http.Request) {
|
func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
return
|
return
|
||||||
@@ -190,7 +230,7 @@ func (s *HTTPServer) handleConnect(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleStatus handles the /status endpoint
|
// handleStatus handles the /status endpoint
|
||||||
func (s *HTTPServer) handleStatus(w http.ResponseWriter, r *http.Request) {
|
func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodGet {
|
if r.Method != http.MethodGet {
|
||||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
return
|
return
|
||||||
50
api/api_unix.go
Normal file
50
api/api_unix.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
//go:build !windows
|
||||||
|
// +build !windows
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// createSocketListener creates a Unix domain socket listener
|
||||||
|
func createSocketListener(socketPath string) (net.Listener, error) {
|
||||||
|
// Ensure the directory exists
|
||||||
|
dir := filepath.Dir(socketPath)
|
||||||
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create socket directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove existing socket file if it exists
|
||||||
|
if err := os.RemoveAll(socketPath); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to remove existing socket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err := net.Listen("unix", socketPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to listen on Unix socket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set socket permissions to allow access
|
||||||
|
if err := os.Chmod(socketPath, 0666); err != nil {
|
||||||
|
listener.Close()
|
||||||
|
return nil, fmt.Errorf("failed to set socket permissions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Created Unix socket at %s", socketPath)
|
||||||
|
return listener, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupSocket removes the Unix socket file
|
||||||
|
func cleanupSocket(socketPath string) {
|
||||||
|
if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
|
||||||
|
logger.Error("Failed to remove socket file %s: %v", socketPath, err)
|
||||||
|
} else {
|
||||||
|
logger.Debug("Removed Unix socket at %s", socketPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
41
api/api_windows.go
Normal file
41
api/api_windows.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/Microsoft/go-winio"
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// createSocketListener creates a Windows named pipe listener
|
||||||
|
func createSocketListener(pipePath string) (net.Listener, error) {
|
||||||
|
// Ensure the pipe path has the correct format
|
||||||
|
if pipePath[0] != '\\' {
|
||||||
|
pipePath = `\\.\pipe\` + pipePath
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a pipe configuration that allows everyone to write
|
||||||
|
config := &winio.PipeConfig{
|
||||||
|
// Set security descriptor to allow everyone full access
|
||||||
|
// This SDDL string grants full access to Everyone (WD) and to the current owner (OW)
|
||||||
|
SecurityDescriptor: "D:(A;;GA;;;WD)(A;;GA;;;OW)",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a named pipe listener using go-winio with the configuration
|
||||||
|
listener, err := winio.ListenPipe(pipePath, config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to listen on named pipe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Created named pipe at %s with write access for everyone", pipePath)
|
||||||
|
return listener, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupSocket is a no-op on Windows as named pipes are automatically cleaned up
|
||||||
|
func cleanupSocket(pipePath string) {
|
||||||
|
logger.Debug("Named pipe %s will be automatically cleaned up", pipePath)
|
||||||
|
}
|
||||||
63
config.go
63
config.go
@@ -27,8 +27,9 @@ type OlmConfig struct {
|
|||||||
LogLevel string `json:"logLevel"`
|
LogLevel string `json:"logLevel"`
|
||||||
|
|
||||||
// HTTP server
|
// HTTP server
|
||||||
EnableHTTP bool `json:"enableHttp"`
|
EnableAPI bool `json:"enableApi"`
|
||||||
HTTPAddr string `json:"httpAddr"`
|
HTTPAddr string `json:"httpAddr"`
|
||||||
|
SocketPath string `json:"socketPath"`
|
||||||
|
|
||||||
// Ping settings
|
// Ping settings
|
||||||
PingInterval string `json:"pingInterval"`
|
PingInterval string `json:"pingInterval"`
|
||||||
@@ -60,13 +61,22 @@ const (
|
|||||||
|
|
||||||
// DefaultConfig returns a config with default values
|
// DefaultConfig returns a config with default values
|
||||||
func DefaultConfig() *OlmConfig {
|
func DefaultConfig() *OlmConfig {
|
||||||
|
// Set OS-specific socket path
|
||||||
|
var socketPath string
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "windows":
|
||||||
|
socketPath = "olm"
|
||||||
|
default: // darwin, linux, and others
|
||||||
|
socketPath = "/var/run/olm.sock"
|
||||||
|
}
|
||||||
|
|
||||||
config := &OlmConfig{
|
config := &OlmConfig{
|
||||||
MTU: 1280,
|
MTU: 1280,
|
||||||
DNS: "8.8.8.8",
|
DNS: "8.8.8.8",
|
||||||
LogLevel: "INFO",
|
LogLevel: "INFO",
|
||||||
InterfaceName: "olm",
|
InterfaceName: "olm",
|
||||||
EnableHTTP: false,
|
EnableAPI: false,
|
||||||
HTTPAddr: ":9452",
|
SocketPath: socketPath,
|
||||||
PingInterval: "3s",
|
PingInterval: "3s",
|
||||||
PingTimeout: "5s",
|
PingTimeout: "5s",
|
||||||
Holepunch: false,
|
Holepunch: false,
|
||||||
@@ -78,8 +88,9 @@ func DefaultConfig() *OlmConfig {
|
|||||||
config.sources["dns"] = string(SourceDefault)
|
config.sources["dns"] = string(SourceDefault)
|
||||||
config.sources["logLevel"] = string(SourceDefault)
|
config.sources["logLevel"] = string(SourceDefault)
|
||||||
config.sources["interface"] = string(SourceDefault)
|
config.sources["interface"] = string(SourceDefault)
|
||||||
config.sources["enableHttp"] = string(SourceDefault)
|
config.sources["enableApi"] = string(SourceDefault)
|
||||||
config.sources["httpAddr"] = string(SourceDefault)
|
config.sources["httpAddr"] = string(SourceDefault)
|
||||||
|
config.sources["socketPath"] = string(SourceDefault)
|
||||||
config.sources["pingInterval"] = string(SourceDefault)
|
config.sources["pingInterval"] = string(SourceDefault)
|
||||||
config.sources["pingTimeout"] = string(SourceDefault)
|
config.sources["pingTimeout"] = string(SourceDefault)
|
||||||
config.sources["holepunch"] = string(SourceDefault)
|
config.sources["holepunch"] = string(SourceDefault)
|
||||||
@@ -209,9 +220,13 @@ func loadConfigFromEnv(config *OlmConfig) {
|
|||||||
config.PingTimeout = val
|
config.PingTimeout = val
|
||||||
config.sources["pingTimeout"] = string(SourceEnv)
|
config.sources["pingTimeout"] = string(SourceEnv)
|
||||||
}
|
}
|
||||||
if val := os.Getenv("ENABLE_HTTP"); val == "true" {
|
if val := os.Getenv("ENABLE_API"); val == "true" {
|
||||||
config.EnableHTTP = true
|
config.EnableAPI = true
|
||||||
config.sources["enableHttp"] = string(SourceEnv)
|
config.sources["enableApi"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("SOCKET_PATH"); val != "" {
|
||||||
|
config.SocketPath = val
|
||||||
|
config.sources["socketPath"] = string(SourceEnv)
|
||||||
}
|
}
|
||||||
if val := os.Getenv("HOLEPUNCH"); val == "true" {
|
if val := os.Getenv("HOLEPUNCH"); val == "true" {
|
||||||
config.Holepunch = true
|
config.Holepunch = true
|
||||||
@@ -233,9 +248,10 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
|
|||||||
"logLevel": config.LogLevel,
|
"logLevel": config.LogLevel,
|
||||||
"interface": config.InterfaceName,
|
"interface": config.InterfaceName,
|
||||||
"httpAddr": config.HTTPAddr,
|
"httpAddr": config.HTTPAddr,
|
||||||
|
"socketPath": config.SocketPath,
|
||||||
"pingInterval": config.PingInterval,
|
"pingInterval": config.PingInterval,
|
||||||
"pingTimeout": config.PingTimeout,
|
"pingTimeout": config.PingTimeout,
|
||||||
"enableHttp": config.EnableHTTP,
|
"enableApi": config.EnableAPI,
|
||||||
"holepunch": config.Holepunch,
|
"holepunch": config.Holepunch,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -248,9 +264,10 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
|
|||||||
serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
||||||
serviceFlags.StringVar(&config.InterfaceName, "interface", config.InterfaceName, "Name of the WireGuard interface")
|
serviceFlags.StringVar(&config.InterfaceName, "interface", config.InterfaceName, "Name of the WireGuard interface")
|
||||||
serviceFlags.StringVar(&config.HTTPAddr, "http-addr", config.HTTPAddr, "HTTP server address (e.g., ':9452')")
|
serviceFlags.StringVar(&config.HTTPAddr, "http-addr", config.HTTPAddr, "HTTP server address (e.g., ':9452')")
|
||||||
|
serviceFlags.StringVar(&config.SocketPath, "socket-path", config.SocketPath, "Unix socket path (or named pipe on Windows)")
|
||||||
serviceFlags.StringVar(&config.PingInterval, "ping-interval", config.PingInterval, "Interval for pinging the server")
|
serviceFlags.StringVar(&config.PingInterval, "ping-interval", config.PingInterval, "Interval for pinging the server")
|
||||||
serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping")
|
serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping")
|
||||||
serviceFlags.BoolVar(&config.EnableHTTP, "enable-http", config.EnableHTTP, "Enable HTTP server for receiving connection requests")
|
serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests")
|
||||||
serviceFlags.BoolVar(&config.Holepunch, "holepunch", config.Holepunch, "Enable hole punching")
|
serviceFlags.BoolVar(&config.Holepunch, "holepunch", config.Holepunch, "Enable hole punching")
|
||||||
|
|
||||||
version := serviceFlags.Bool("version", false, "Print the version")
|
version := serviceFlags.Bool("version", false, "Print the version")
|
||||||
@@ -286,14 +303,17 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
|
|||||||
if config.HTTPAddr != origValues["httpAddr"].(string) {
|
if config.HTTPAddr != origValues["httpAddr"].(string) {
|
||||||
config.sources["httpAddr"] = string(SourceCLI)
|
config.sources["httpAddr"] = string(SourceCLI)
|
||||||
}
|
}
|
||||||
|
if config.SocketPath != origValues["socketPath"].(string) {
|
||||||
|
config.sources["socketPath"] = string(SourceCLI)
|
||||||
|
}
|
||||||
if config.PingInterval != origValues["pingInterval"].(string) {
|
if config.PingInterval != origValues["pingInterval"].(string) {
|
||||||
config.sources["pingInterval"] = string(SourceCLI)
|
config.sources["pingInterval"] = string(SourceCLI)
|
||||||
}
|
}
|
||||||
if config.PingTimeout != origValues["pingTimeout"].(string) {
|
if config.PingTimeout != origValues["pingTimeout"].(string) {
|
||||||
config.sources["pingTimeout"] = string(SourceCLI)
|
config.sources["pingTimeout"] = string(SourceCLI)
|
||||||
}
|
}
|
||||||
if config.EnableHTTP != origValues["enableHttp"].(bool) {
|
if config.EnableAPI != origValues["enableApi"].(bool) {
|
||||||
config.sources["enableHttp"] = string(SourceCLI)
|
config.sources["enableApi"] = string(SourceCLI)
|
||||||
}
|
}
|
||||||
if config.Holepunch != origValues["holepunch"].(bool) {
|
if config.Holepunch != origValues["holepunch"].(bool) {
|
||||||
config.sources["holepunch"] = string(SourceCLI)
|
config.sources["holepunch"] = string(SourceCLI)
|
||||||
@@ -370,6 +390,14 @@ func mergeConfigs(dest, src *OlmConfig) {
|
|||||||
dest.HTTPAddr = src.HTTPAddr
|
dest.HTTPAddr = src.HTTPAddr
|
||||||
dest.sources["httpAddr"] = string(SourceFile)
|
dest.sources["httpAddr"] = string(SourceFile)
|
||||||
}
|
}
|
||||||
|
if src.SocketPath != "" {
|
||||||
|
// Check if it's not the default for any OS
|
||||||
|
isDefault := src.SocketPath == "/var/run/olm.sock" || src.SocketPath == "olm"
|
||||||
|
if !isDefault {
|
||||||
|
dest.SocketPath = src.SocketPath
|
||||||
|
dest.sources["socketPath"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
}
|
||||||
if src.PingInterval != "" && src.PingInterval != "3s" {
|
if src.PingInterval != "" && src.PingInterval != "3s" {
|
||||||
dest.PingInterval = src.PingInterval
|
dest.PingInterval = src.PingInterval
|
||||||
dest.sources["pingInterval"] = string(SourceFile)
|
dest.sources["pingInterval"] = string(SourceFile)
|
||||||
@@ -383,9 +411,9 @@ func mergeConfigs(dest, src *OlmConfig) {
|
|||||||
dest.sources["tlsClientCert"] = string(SourceFile)
|
dest.sources["tlsClientCert"] = string(SourceFile)
|
||||||
}
|
}
|
||||||
// For booleans, we always take the source value if explicitly set
|
// For booleans, we always take the source value if explicitly set
|
||||||
if src.EnableHTTP {
|
if src.EnableAPI {
|
||||||
dest.EnableHTTP = src.EnableHTTP
|
dest.EnableAPI = src.EnableAPI
|
||||||
dest.sources["enableHttp"] = string(SourceFile)
|
dest.sources["enableApi"] = string(SourceFile)
|
||||||
}
|
}
|
||||||
if src.Holepunch {
|
if src.Holepunch {
|
||||||
dest.Holepunch = src.Holepunch
|
dest.Holepunch = src.Holepunch
|
||||||
@@ -458,10 +486,11 @@ func (c *OlmConfig) ShowConfig() {
|
|||||||
fmt.Println("\nLogging:")
|
fmt.Println("\nLogging:")
|
||||||
fmt.Printf(" log-level = %s [%s]\n", c.LogLevel, getSource("logLevel"))
|
fmt.Printf(" log-level = %s [%s]\n", c.LogLevel, getSource("logLevel"))
|
||||||
|
|
||||||
// HTTP server
|
// API server
|
||||||
fmt.Println("\nHTTP Server:")
|
fmt.Println("\nAPI Server:")
|
||||||
fmt.Printf(" enable-http = %v [%s]\n", c.EnableHTTP, getSource("enableHttp"))
|
fmt.Printf(" enable-api = %v [%s]\n", c.EnableAPI, getSource("enableApi"))
|
||||||
fmt.Printf(" http-addr = %s [%s]\n", c.HTTPAddr, getSource("httpAddr"))
|
fmt.Printf(" http-addr = %s [%s]\n", c.HTTPAddr, getSource("httpAddr"))
|
||||||
|
fmt.Printf(" socket-path = %s [%s]\n", c.SocketPath, getSource("socketPath"))
|
||||||
|
|
||||||
// Timing
|
// Timing
|
||||||
fmt.Println("\nTiming:")
|
fmt.Println("\nTiming:")
|
||||||
|
|||||||
11
main.go
11
main.go
@@ -4,7 +4,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"os/signal"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/olm/olm"
|
"github.com/fosrl/olm/olm"
|
||||||
@@ -197,8 +199,9 @@ func main() {
|
|||||||
DNS: config.DNS,
|
DNS: config.DNS,
|
||||||
InterfaceName: config.InterfaceName,
|
InterfaceName: config.InterfaceName,
|
||||||
LogLevel: config.LogLevel,
|
LogLevel: config.LogLevel,
|
||||||
EnableHTTP: config.EnableHTTP,
|
EnableAPI: config.EnableAPI,
|
||||||
HTTPAddr: config.HTTPAddr,
|
HTTPAddr: config.HTTPAddr,
|
||||||
|
SocketPath: config.SocketPath,
|
||||||
PingInterval: config.PingInterval,
|
PingInterval: config.PingInterval,
|
||||||
PingTimeout: config.PingTimeout,
|
PingTimeout: config.PingTimeout,
|
||||||
Holepunch: config.Holepunch,
|
Holepunch: config.Holepunch,
|
||||||
@@ -208,5 +211,9 @@ func main() {
|
|||||||
Version: config.Version,
|
Version: config.Version,
|
||||||
}
|
}
|
||||||
|
|
||||||
olm.Run(context.Background(), olmConfig)
|
// Create a context that will be cancelled on interrupt signals
|
||||||
|
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
olm.Run(ctx, olmConfig)
|
||||||
}
|
}
|
||||||
|
|||||||
78
olm/olm.go
78
olm/olm.go
@@ -12,7 +12,7 @@ import (
|
|||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/newt/updates"
|
"github.com/fosrl/newt/updates"
|
||||||
"github.com/fosrl/olm/httpserver"
|
"github.com/fosrl/olm/api"
|
||||||
"github.com/fosrl/olm/peermonitor"
|
"github.com/fosrl/olm/peermonitor"
|
||||||
"github.com/fosrl/olm/websocket"
|
"github.com/fosrl/olm/websocket"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
@@ -35,8 +35,9 @@ type Config struct {
|
|||||||
LogLevel string
|
LogLevel string
|
||||||
|
|
||||||
// HTTP server
|
// HTTP server
|
||||||
EnableHTTP bool
|
EnableAPI bool
|
||||||
HTTPAddr string
|
HTTPAddr string
|
||||||
|
SocketPath string
|
||||||
|
|
||||||
// Ping settings
|
// Ping settings
|
||||||
PingInterval string
|
PingInterval string
|
||||||
@@ -65,8 +66,6 @@ func Run(ctx context.Context, config Config) {
|
|||||||
mtu = config.MTU
|
mtu = config.MTU
|
||||||
logLevel = config.LogLevel
|
logLevel = config.LogLevel
|
||||||
interfaceName = config.InterfaceName
|
interfaceName = config.InterfaceName
|
||||||
enableHTTP = config.EnableHTTP
|
|
||||||
httpAddr = config.HTTPAddr
|
|
||||||
pingInterval = config.PingIntervalDuration
|
pingInterval = config.PingIntervalDuration
|
||||||
pingTimeout = config.PingTimeoutDuration
|
pingTimeout = config.PingTimeoutDuration
|
||||||
doHolepunch = config.Holepunch
|
doHolepunch = config.Holepunch
|
||||||
@@ -92,33 +91,38 @@ func Run(ctx context.Context, config Config) {
|
|||||||
// Log startup information
|
// Log startup information
|
||||||
logger.Debug("Olm service starting...")
|
logger.Debug("Olm service starting...")
|
||||||
logger.Debug("Parameters: endpoint='%s', id='%s', secret='%s'", endpoint, id, secret)
|
logger.Debug("Parameters: endpoint='%s', id='%s', secret='%s'", endpoint, id, secret)
|
||||||
logger.Debug("HTTP enabled: %v, HTTP addr: %s", enableHTTP, httpAddr)
|
|
||||||
|
|
||||||
if doHolepunch {
|
if doHolepunch {
|
||||||
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
|
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
|
||||||
}
|
}
|
||||||
|
|
||||||
var httpServer *httpserver.HTTPServer
|
var apiServer *api.API
|
||||||
if enableHTTP {
|
if config.EnableAPI {
|
||||||
httpServer = httpserver.NewHTTPServer(httpAddr)
|
if config.HTTPAddr != "" {
|
||||||
httpServer.SetVersion(config.Version)
|
apiServer = api.NewAPI(config.HTTPAddr)
|
||||||
if err := httpServer.Start(); err != nil {
|
} else if config.SocketPath != "" {
|
||||||
logger.Fatal("Failed to start HTTP server: %v", err)
|
apiServer = api.NewAPISocket(config.SocketPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use a goroutine to handle connection requests
|
apiServer.SetVersion(config.Version)
|
||||||
go func() {
|
if err := apiServer.Start(); err != nil {
|
||||||
for req := range httpServer.GetConnectionChannel() {
|
logger.Fatal("Failed to start HTTP server: %v", err)
|
||||||
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
|
}
|
||||||
|
|
||||||
// Set the connection parameters
|
|
||||||
id = req.ID
|
|
||||||
secret = req.Secret
|
|
||||||
endpoint = req.Endpoint
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// // Use a goroutine to handle connection requests
|
||||||
|
// go func() {
|
||||||
|
// for req := range apiServer.GetConnectionChannel() {
|
||||||
|
// logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
|
||||||
|
|
||||||
|
// // Set the connection parameters
|
||||||
|
// id = req.ID
|
||||||
|
// secret = req.Secret
|
||||||
|
// endpoint = req.Endpoint
|
||||||
|
// }
|
||||||
|
// }()
|
||||||
|
// }
|
||||||
|
|
||||||
// Create a new olm
|
// Create a new olm
|
||||||
olm, err := websocket.NewClient(
|
olm, err := websocket.NewClient(
|
||||||
"olm",
|
"olm",
|
||||||
@@ -329,13 +333,13 @@ func Run(ctx context.Context, config Config) {
|
|||||||
if err = ConfigureInterface(interfaceName, wgData); err != nil {
|
if err = ConfigureInterface(interfaceName, wgData); err != nil {
|
||||||
logger.Error("Failed to configure interface: %v", err)
|
logger.Error("Failed to configure interface: %v", err)
|
||||||
}
|
}
|
||||||
if httpServer != nil {
|
if apiServer != nil {
|
||||||
httpServer.SetTunnelIP(wgData.TunnelIP)
|
apiServer.SetTunnelIP(wgData.TunnelIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
peerMonitor = peermonitor.NewPeerMonitor(
|
peerMonitor = peermonitor.NewPeerMonitor(
|
||||||
func(siteID int, connected bool, rtt time.Duration) {
|
func(siteID int, connected bool, rtt time.Duration) {
|
||||||
if httpServer != nil {
|
if apiServer != nil {
|
||||||
// Find the site config to get endpoint information
|
// Find the site config to get endpoint information
|
||||||
var endpoint string
|
var endpoint string
|
||||||
var isRelay bool
|
var isRelay bool
|
||||||
@@ -348,7 +352,7 @@ func Run(ctx context.Context, config Config) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
httpServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
|
apiServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
|
||||||
}
|
}
|
||||||
if connected {
|
if connected {
|
||||||
logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
|
logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
|
||||||
@@ -364,8 +368,8 @@ func Run(ctx context.Context, config Config) {
|
|||||||
|
|
||||||
for i := range wgData.Sites {
|
for i := range wgData.Sites {
|
||||||
site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice
|
site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice
|
||||||
if httpServer != nil {
|
if apiServer != nil {
|
||||||
httpServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
|
apiServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Format the endpoint before configuring the peer.
|
// Format the endpoint before configuring the peer.
|
||||||
@@ -615,8 +619,8 @@ func Run(ctx context.Context, config Config) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update HTTP server to mark this peer as using relay
|
// Update HTTP server to mark this peer as using relay
|
||||||
if httpServer != nil {
|
if apiServer != nil {
|
||||||
httpServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
|
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
|
peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
|
||||||
@@ -648,8 +652,8 @@ func Run(ctx context.Context, config Config) {
|
|||||||
olm.OnConnect(func() error {
|
olm.OnConnect(func() error {
|
||||||
logger.Info("Websocket Connected")
|
logger.Info("Websocket Connected")
|
||||||
|
|
||||||
if httpServer != nil {
|
if apiServer != nil {
|
||||||
httpServer.SetConnectionStatus(true)
|
apiServer.SetConnectionStatus(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
if connected {
|
if connected {
|
||||||
@@ -707,10 +711,20 @@ func Run(ctx context.Context, config Config) {
|
|||||||
close(stopPing)
|
close(stopPing)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if peerMonitor != nil {
|
||||||
|
peerMonitor.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
if uapiListener != nil {
|
if uapiListener != nil {
|
||||||
uapiListener.Close()
|
uapiListener.Close()
|
||||||
}
|
}
|
||||||
if dev != nil {
|
if dev != nil {
|
||||||
dev.Close()
|
dev.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if apiServer != nil {
|
||||||
|
apiServer.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Olm service stopped")
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user