diff --git a/entrypoint.sh b/entrypoint.sh index 5058133..23310b2 100644 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -4,7 +4,11 @@ set -e # first arg is `-f` or `--some-option` if [ "${1#-}" != "$1" ]; then +<<<<<<< HEAD set -- gerbil "$@" +======= + set -- newt "$@" +>>>>>>> env-vars fi exec "$@" \ No newline at end of file diff --git a/main.go b/main.go index e049f21..6fbfd24 100644 --- a/main.go +++ b/main.go @@ -21,9 +21,9 @@ import ( ) var ( - interfaceName = "wg0" - listenAddr = ":3003" - mtuInt = 1420 + interfaceName string + listenAddr string + mtuInt int lastReadings = make(map[string]PeerReading) mu sync.Mutex ) @@ -74,58 +74,86 @@ func parseLogLevel(level string) logger.LogLevel { } func main() { - var err error - var wgconfig WgConfig + var ( + err error + wgconfig WgConfig + configFile string + remoteConfigURL string + reportBandwidthTo string + generateAndSaveKeyTo string + reachableAt string + logLevel string + mtu string + ) - // Define command line flags - interfaceNameArg := flag.String("interface", "wg0", "Name of the WireGuard interface") - mtu := flag.String("mtu", "1280", "MTU of the interface") - configFile := flag.String("config", "", "Path to local configuration file") - remoteConfigURL := flag.String("remoteConfig", "", "URL to fetch remote configuration") - listenAddrArg := flag.String("listen", ":3003", "Address to listen on") - reportBandwidthTo := flag.String("reportBandwidthTo", "", "Address to listen on") - generateAndSaveKeyTo := flag.String("generateAndSaveKeyTo", "", "Path to save generated private key") - reachableAt := flag.String("reachableAt", "", "Endpoint of the http server to tell remote config about") - logLevel := flag.String("log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") + interfaceName = os.Getenv("INTERFACE") + configFile = os.Getenv("CONFIG") + remoteConfigURL = os.Getenv("REMOTE_CONFIG") + listenAddr = os.Getenv("LISTEN") + reportBandwidthTo = os.Getenv("REPORT_BANDWIDTH_TO") + generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO") + reachableAt = os.Getenv("REACHABLE_AT") + logLevel = os.Getenv("LOG_LEVEL") + mtu = os.Getenv("MTU") + if interfaceName == "" { + flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface") + } + if configFile == "" { + flag.StringVar(&configFile, "config", "", "Path to local configuration file") + } + if remoteConfigURL == "" { + flag.StringVar(&remoteConfigURL, "remoteConfig", "", "URL to fetch remote configuration") + } + if listenAddr == "" { + flag.StringVar(&listenAddr, "listen", ":3003", "Address to listen on") + } + if reportBandwidthTo == "" { + flag.StringVar(&reportBandwidthTo, "reportBandwidthTo", "", "Address to listen on") + } + if generateAndSaveKeyTo == "" { + flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key") + } + if reachableAt == "" { + flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about") + } + if logLevel == "" { + flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") + } + if mtu == "" { + flag.StringVar(&mtu, "mtu", "1420", "MTU of the WireGuard interface") + } flag.Parse() logger.Init() - logger.GetLogger().SetLevel(parseLogLevel(*logLevel)) + logger.GetLogger().SetLevel(parseLogLevel(logLevel)) - if *interfaceNameArg != "" { - interfaceName = *interfaceNameArg - } - if *listenAddrArg != "" { - listenAddr = *listenAddrArg - } - - mtuInt, err = strconv.Atoi(*mtu) + mtuInt, err = strconv.Atoi(mtu) if err != nil { logger.Fatal("Failed to parse MTU: %v", err) } // Validate that only one config option is provided - if (*configFile != "" && *remoteConfigURL != "") || (*configFile == "" && *remoteConfigURL == "") { + if (configFile != "" && remoteConfigURL != "") || (configFile == "" && remoteConfigURL == "") { logger.Fatal("Please provide either --config or --remoteConfig, but not both") } var key wgtypes.Key // if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file - if *generateAndSaveKeyTo != "" { - if _, err := os.Stat(*generateAndSaveKeyTo); os.IsNotExist(err) { + if generateAndSaveKeyTo != "" { + if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) { // generate a new private key key, err = wgtypes.GeneratePrivateKey() if err != nil { logger.Fatal("Failed to generate private key: %v", err) } // save the key to the file - err = os.WriteFile(*generateAndSaveKeyTo, []byte(key.String()), 0644) + err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644) if err != nil { logger.Fatal("Failed to save private key: %v", err) } } else { - keyData, err := os.ReadFile(*generateAndSaveKeyTo) + keyData, err := os.ReadFile(generateAndSaveKeyTo) if err != nil { logger.Fatal("Failed to read private key: %v", err) } @@ -146,8 +174,8 @@ func main() { } // Load configuration based on provided argument - if *configFile != "" { - wgconfig, err = loadConfig(*configFile) + if configFile != "" { + wgconfig, err = loadConfig(configFile) if err != nil { logger.Fatal("Failed to load configuration: %v", err) } @@ -155,7 +183,7 @@ func main() { wgconfig.PrivateKey = key.String() } } else { - wgconfig, err = loadRemoteConfig(*remoteConfigURL, key, *reachableAt) + wgconfig, err = loadRemoteConfig(remoteConfigURL, key, reachableAt) if err != nil { logger.Fatal("Failed to load configuration: %v", err) } @@ -176,8 +204,8 @@ func main() { // Ensure the WireGuard peers exist ensureWireguardPeers(wgconfig.Peers) - if *reportBandwidthTo != "" { - go periodicBandwidthCheck(*reportBandwidthTo) + if reportBandwidthTo != "" { + go periodicBandwidthCheck(reportBandwidthTo) } http.HandleFunc("/peer", handlePeer)