mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
Updated the olm client to process config vars from cli,env,file in order of precedence and persist them to file
Former-commit-id: 555c9dc9f4
This commit is contained in:
185
main.go
185
main.go
@@ -3,7 +3,6 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
@@ -19,7 +18,6 @@ import (
|
||||
"github.com/fosrl/newt/websocket"
|
||||
"github.com/fosrl/olm/httpserver"
|
||||
"github.com/fosrl/olm/peermonitor"
|
||||
"github.com/fosrl/olm/wgtester"
|
||||
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
@@ -204,122 +202,41 @@ func runOlmMain(ctx context.Context) {
|
||||
}
|
||||
|
||||
func runOlmMainWithArgs(ctx context.Context, args []string) {
|
||||
// Log that we've entered the main function
|
||||
// fmt.Printf("runOlmMainWithArgs() called with args: %v\n", args)
|
||||
// Load configuration from file, env vars, and CLI args
|
||||
// Priority: CLI args > Env vars > Config file > Defaults
|
||||
config, showVersion, showConfig, err := LoadConfig(args)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to load configuration: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a new FlagSet for parsing service arguments
|
||||
serviceFlags := flag.NewFlagSet("service", flag.ContinueOnError)
|
||||
// Handle --show-config flag
|
||||
if showConfig {
|
||||
config.ShowConfig()
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// Extract commonly used values from config for convenience
|
||||
var (
|
||||
endpoint string
|
||||
id string
|
||||
secret string
|
||||
mtu string
|
||||
endpoint = config.Endpoint
|
||||
id = config.ID
|
||||
secret = config.Secret
|
||||
mtu = config.MTU
|
||||
mtuInt int
|
||||
dns string
|
||||
logLevel = config.LogLevel
|
||||
interfaceName = config.InterfaceName
|
||||
enableHTTP = config.EnableHTTP
|
||||
httpAddr = config.HTTPAddr
|
||||
pingInterval = config.PingIntervalDuration
|
||||
pingTimeout = config.PingTimeoutDuration
|
||||
doHolepunch = config.Holepunch
|
||||
privateKey wgtypes.Key
|
||||
err error
|
||||
logLevel string
|
||||
interfaceName string
|
||||
enableHTTP bool
|
||||
httpAddr string
|
||||
testMode bool // Add this var for the test flag
|
||||
testTarget string // Add this var for test target
|
||||
pingInterval time.Duration
|
||||
pingTimeout time.Duration
|
||||
doHolepunch bool
|
||||
connected bool
|
||||
)
|
||||
|
||||
stopHolepunch = make(chan struct{})
|
||||
stopPing = make(chan struct{})
|
||||
|
||||
// if PANGOLIN_ENDPOINT, OLM_ID, and OLM_SECRET are set as environment variables, they will be used as default values
|
||||
endpoint = os.Getenv("PANGOLIN_ENDPOINT")
|
||||
id = os.Getenv("OLM_ID")
|
||||
secret = os.Getenv("OLM_SECRET")
|
||||
mtu = os.Getenv("MTU")
|
||||
dns = os.Getenv("DNS")
|
||||
logLevel = os.Getenv("LOG_LEVEL")
|
||||
interfaceName = os.Getenv("INTERFACE")
|
||||
httpAddr = os.Getenv("HTTP_ADDR")
|
||||
pingIntervalStr := os.Getenv("PING_INTERVAL")
|
||||
pingTimeoutStr := os.Getenv("PING_TIMEOUT")
|
||||
enableHTTPEnv := os.Getenv("ENABLE_HTTP")
|
||||
holepunchEnv := os.Getenv("HOLEPUNCH")
|
||||
|
||||
enableHTTP = enableHTTPEnv == "true"
|
||||
doHolepunch = holepunchEnv == "true"
|
||||
|
||||
if endpoint == "" {
|
||||
serviceFlags.StringVar(&endpoint, "endpoint", "", "Endpoint of your Pangolin server")
|
||||
}
|
||||
if id == "" {
|
||||
serviceFlags.StringVar(&id, "id", "", "Olm ID")
|
||||
}
|
||||
if secret == "" {
|
||||
serviceFlags.StringVar(&secret, "secret", "", "Olm secret")
|
||||
}
|
||||
if mtu == "" {
|
||||
serviceFlags.StringVar(&mtu, "mtu", "1280", "MTU to use")
|
||||
}
|
||||
if dns == "" {
|
||||
serviceFlags.StringVar(&dns, "dns", "8.8.8.8", "DNS server to use")
|
||||
}
|
||||
if logLevel == "" {
|
||||
serviceFlags.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
||||
}
|
||||
if interfaceName == "" {
|
||||
serviceFlags.StringVar(&interfaceName, "interface", "olm", "Name of the WireGuard interface")
|
||||
}
|
||||
if httpAddr == "" {
|
||||
serviceFlags.StringVar(&httpAddr, "http-addr", ":9452", "HTTP server address (e.g., ':9452')")
|
||||
}
|
||||
if pingIntervalStr == "" {
|
||||
serviceFlags.StringVar(&pingIntervalStr, "ping-interval", "3s", "Interval for pinging the server (default 3s)")
|
||||
}
|
||||
if pingTimeoutStr == "" {
|
||||
serviceFlags.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 3s)")
|
||||
}
|
||||
if enableHTTPEnv == "" {
|
||||
serviceFlags.BoolVar(&enableHTTP, "enable-http", false, "Enable HTT server for receiving connection requests")
|
||||
}
|
||||
if holepunchEnv == "" {
|
||||
serviceFlags.BoolVar(&doHolepunch, "holepunch", false, "Enable hole punching (default false)")
|
||||
}
|
||||
|
||||
version := serviceFlags.Bool("version", false, "Print the version")
|
||||
|
||||
// Parse the service arguments
|
||||
if err := serviceFlags.Parse(args); err != nil {
|
||||
fmt.Printf("Error parsing service arguments: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Debug: Print final values after flag parsing
|
||||
// fmt.Printf("After flag parsing: endpoint='%s', id='%s', secret='%s'\n", endpoint, id, secret)
|
||||
|
||||
// Parse ping intervals
|
||||
if pingIntervalStr != "" {
|
||||
pingInterval, err = time.ParseDuration(pingIntervalStr)
|
||||
if err != nil {
|
||||
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 3 seconds\n", pingIntervalStr)
|
||||
pingInterval = 3 * time.Second
|
||||
}
|
||||
} else {
|
||||
pingInterval = 3 * time.Second
|
||||
}
|
||||
|
||||
if pingTimeoutStr != "" {
|
||||
pingTimeout, err = time.ParseDuration(pingTimeoutStr)
|
||||
if err != nil {
|
||||
fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 5 seconds\n", pingTimeoutStr)
|
||||
pingTimeout = 5 * time.Second
|
||||
}
|
||||
} else {
|
||||
pingTimeout = 5 * time.Second
|
||||
}
|
||||
|
||||
// Setup Windows event logging if on Windows
|
||||
if runtime.GOOS == "windows" {
|
||||
setupWindowsEventLog()
|
||||
@@ -331,12 +248,11 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
||||
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
|
||||
|
||||
olmVersion := "version_replaceme"
|
||||
if *version {
|
||||
if showVersion {
|
||||
fmt.Println("Olm version " + olmVersion)
|
||||
os.Exit(0)
|
||||
} else {
|
||||
logger.Info("Olm version " + olmVersion)
|
||||
}
|
||||
logger.Info("Olm version " + olmVersion)
|
||||
|
||||
if err := updates.CheckForUpdate("fosrl", "olm", olmVersion); err != nil {
|
||||
logger.Debug("Failed to check for updates: %v", err)
|
||||
@@ -351,35 +267,6 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
||||
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
|
||||
}
|
||||
|
||||
// Handle test mode
|
||||
if testMode {
|
||||
if testTarget == "" {
|
||||
logger.Fatal("Test mode requires -test-target to be set to a server:port")
|
||||
}
|
||||
|
||||
logger.Info("Running in test mode, connecting to %s", testTarget)
|
||||
|
||||
// Create a new tester client
|
||||
tester, err := wgtester.NewClient(testTarget)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create tester client: %v", err)
|
||||
}
|
||||
defer tester.Close()
|
||||
|
||||
// Test connection with a 2-second timeout
|
||||
connected, rtt := tester.TestConnectionWithTimeout(2 * time.Second)
|
||||
|
||||
if connected {
|
||||
logger.Info("Connection test successful! RTT: %v", rtt)
|
||||
fmt.Printf("Connection test successful! RTT: %v\n", rtt)
|
||||
os.Exit(0)
|
||||
} else {
|
||||
logger.Error("Connection test failed - no response received")
|
||||
fmt.Println("Connection test failed - no response received")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
var httpServer *httpserver.HTTPServer
|
||||
if enableHTTP {
|
||||
httpServer = httpserver.NewHTTPServer(httpAddr)
|
||||
@@ -437,9 +324,15 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create olm: %v", err)
|
||||
}
|
||||
endpoint = olm.GetConfig().Endpoint // Update endpoint from config
|
||||
id = olm.GetConfig().ID // Update ID from config
|
||||
secret = olm.GetConfig().Secret // Update secret from config
|
||||
// Update config with values from websocket client (which may have loaded from its config file)
|
||||
config.UpdateFromWebsocket(
|
||||
olm.GetConfig().ID,
|
||||
olm.GetConfig().Secret,
|
||||
olm.GetConfig().Endpoint,
|
||||
)
|
||||
endpoint = config.Endpoint
|
||||
id = config.ID
|
||||
secret = config.Secret
|
||||
|
||||
// wait until we have a client id and secret and endpoint
|
||||
waitCount := 0
|
||||
@@ -974,6 +867,14 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
||||
httpServer.SetConnectionStatus(true)
|
||||
}
|
||||
|
||||
// CRITICAL: Save our full config AFTER websocket saves its limited config
|
||||
// This ensures all 13 fields are preserved, not just the 4 that websocket saves
|
||||
if err := SaveConfig(config); err != nil {
|
||||
logger.Error("Failed to save full olm config: %v", err)
|
||||
} else {
|
||||
logger.Debug("Saved full olm config with all options")
|
||||
}
|
||||
|
||||
if connected {
|
||||
logger.Debug("Already connected, skipping registration")
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user