Updated the olm client to process config vars from cli,env,file in order of precedence and persist them to file

Former-commit-id: 555c9dc9f4
This commit is contained in:
gk1
2025-10-24 17:04:48 -07:00
parent af0a72d296
commit 2d34c6c8b2
2 changed files with 535 additions and 142 deletions

185
main.go
View File

@@ -3,7 +3,6 @@ package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"net"
"os"
@@ -19,7 +18,6 @@ import (
"github.com/fosrl/newt/websocket"
"github.com/fosrl/olm/httpserver"
"github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/wgtester"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
@@ -204,122 +202,41 @@ func runOlmMain(ctx context.Context) {
}
func runOlmMainWithArgs(ctx context.Context, args []string) {
// Log that we've entered the main function
// fmt.Printf("runOlmMainWithArgs() called with args: %v\n", args)
// Load configuration from file, env vars, and CLI args
// Priority: CLI args > Env vars > Config file > Defaults
config, showVersion, showConfig, err := LoadConfig(args)
if err != nil {
fmt.Printf("Failed to load configuration: %v\n", err)
return
}
// Create a new FlagSet for parsing service arguments
serviceFlags := flag.NewFlagSet("service", flag.ContinueOnError)
// Handle --show-config flag
if showConfig {
config.ShowConfig()
os.Exit(0)
}
// Extract commonly used values from config for convenience
var (
endpoint string
id string
secret string
mtu string
endpoint = config.Endpoint
id = config.ID
secret = config.Secret
mtu = config.MTU
mtuInt int
dns string
logLevel = config.LogLevel
interfaceName = config.InterfaceName
enableHTTP = config.EnableHTTP
httpAddr = config.HTTPAddr
pingInterval = config.PingIntervalDuration
pingTimeout = config.PingTimeoutDuration
doHolepunch = config.Holepunch
privateKey wgtypes.Key
err error
logLevel string
interfaceName string
enableHTTP bool
httpAddr string
testMode bool // Add this var for the test flag
testTarget string // Add this var for test target
pingInterval time.Duration
pingTimeout time.Duration
doHolepunch bool
connected bool
)
stopHolepunch = make(chan struct{})
stopPing = make(chan struct{})
// if PANGOLIN_ENDPOINT, OLM_ID, and OLM_SECRET are set as environment variables, they will be used as default values
endpoint = os.Getenv("PANGOLIN_ENDPOINT")
id = os.Getenv("OLM_ID")
secret = os.Getenv("OLM_SECRET")
mtu = os.Getenv("MTU")
dns = os.Getenv("DNS")
logLevel = os.Getenv("LOG_LEVEL")
interfaceName = os.Getenv("INTERFACE")
httpAddr = os.Getenv("HTTP_ADDR")
pingIntervalStr := os.Getenv("PING_INTERVAL")
pingTimeoutStr := os.Getenv("PING_TIMEOUT")
enableHTTPEnv := os.Getenv("ENABLE_HTTP")
holepunchEnv := os.Getenv("HOLEPUNCH")
enableHTTP = enableHTTPEnv == "true"
doHolepunch = holepunchEnv == "true"
if endpoint == "" {
serviceFlags.StringVar(&endpoint, "endpoint", "", "Endpoint of your Pangolin server")
}
if id == "" {
serviceFlags.StringVar(&id, "id", "", "Olm ID")
}
if secret == "" {
serviceFlags.StringVar(&secret, "secret", "", "Olm secret")
}
if mtu == "" {
serviceFlags.StringVar(&mtu, "mtu", "1280", "MTU to use")
}
if dns == "" {
serviceFlags.StringVar(&dns, "dns", "8.8.8.8", "DNS server to use")
}
if logLevel == "" {
serviceFlags.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
}
if interfaceName == "" {
serviceFlags.StringVar(&interfaceName, "interface", "olm", "Name of the WireGuard interface")
}
if httpAddr == "" {
serviceFlags.StringVar(&httpAddr, "http-addr", ":9452", "HTTP server address (e.g., ':9452')")
}
if pingIntervalStr == "" {
serviceFlags.StringVar(&pingIntervalStr, "ping-interval", "3s", "Interval for pinging the server (default 3s)")
}
if pingTimeoutStr == "" {
serviceFlags.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 3s)")
}
if enableHTTPEnv == "" {
serviceFlags.BoolVar(&enableHTTP, "enable-http", false, "Enable HTT server for receiving connection requests")
}
if holepunchEnv == "" {
serviceFlags.BoolVar(&doHolepunch, "holepunch", false, "Enable hole punching (default false)")
}
version := serviceFlags.Bool("version", false, "Print the version")
// Parse the service arguments
if err := serviceFlags.Parse(args); err != nil {
fmt.Printf("Error parsing service arguments: %v\n", err)
return
}
// Debug: Print final values after flag parsing
// fmt.Printf("After flag parsing: endpoint='%s', id='%s', secret='%s'\n", endpoint, id, secret)
// Parse ping intervals
if pingIntervalStr != "" {
pingInterval, err = time.ParseDuration(pingIntervalStr)
if err != nil {
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 3 seconds\n", pingIntervalStr)
pingInterval = 3 * time.Second
}
} else {
pingInterval = 3 * time.Second
}
if pingTimeoutStr != "" {
pingTimeout, err = time.ParseDuration(pingTimeoutStr)
if err != nil {
fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 5 seconds\n", pingTimeoutStr)
pingTimeout = 5 * time.Second
}
} else {
pingTimeout = 5 * time.Second
}
// Setup Windows event logging if on Windows
if runtime.GOOS == "windows" {
setupWindowsEventLog()
@@ -331,12 +248,11 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
olmVersion := "version_replaceme"
if *version {
if showVersion {
fmt.Println("Olm version " + olmVersion)
os.Exit(0)
} else {
logger.Info("Olm version " + olmVersion)
}
logger.Info("Olm version " + olmVersion)
if err := updates.CheckForUpdate("fosrl", "olm", olmVersion); err != nil {
logger.Debug("Failed to check for updates: %v", err)
@@ -351,35 +267,6 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
}
// Handle test mode
if testMode {
if testTarget == "" {
logger.Fatal("Test mode requires -test-target to be set to a server:port")
}
logger.Info("Running in test mode, connecting to %s", testTarget)
// Create a new tester client
tester, err := wgtester.NewClient(testTarget)
if err != nil {
logger.Fatal("Failed to create tester client: %v", err)
}
defer tester.Close()
// Test connection with a 2-second timeout
connected, rtt := tester.TestConnectionWithTimeout(2 * time.Second)
if connected {
logger.Info("Connection test successful! RTT: %v", rtt)
fmt.Printf("Connection test successful! RTT: %v\n", rtt)
os.Exit(0)
} else {
logger.Error("Connection test failed - no response received")
fmt.Println("Connection test failed - no response received")
os.Exit(1)
}
}
var httpServer *httpserver.HTTPServer
if enableHTTP {
httpServer = httpserver.NewHTTPServer(httpAddr)
@@ -437,9 +324,15 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
if err != nil {
logger.Fatal("Failed to create olm: %v", err)
}
endpoint = olm.GetConfig().Endpoint // Update endpoint from config
id = olm.GetConfig().ID // Update ID from config
secret = olm.GetConfig().Secret // Update secret from config
// Update config with values from websocket client (which may have loaded from its config file)
config.UpdateFromWebsocket(
olm.GetConfig().ID,
olm.GetConfig().Secret,
olm.GetConfig().Endpoint,
)
endpoint = config.Endpoint
id = config.ID
secret = config.Secret
// wait until we have a client id and secret and endpoint
waitCount := 0
@@ -974,6 +867,14 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
httpServer.SetConnectionStatus(true)
}
// CRITICAL: Save our full config AFTER websocket saves its limited config
// This ensures all 13 fields are preserved, not just the 4 that websocket saves
if err := SaveConfig(config); err != nil {
logger.Error("Failed to save full olm config: %v", err)
} else {
logger.Debug("Saved full olm config with all options")
}
if connected {
logger.Debug("Already connected, skipping registration")
return nil