Merge branch 'env-vars' into dev

This commit is contained in:
Owen Schwartz
2025-01-14 23:45:12 -05:00
2 changed files with 66 additions and 34 deletions

View File

@@ -4,7 +4,11 @@ set -e
# first arg is `-f` or `--some-option` # first arg is `-f` or `--some-option`
if [ "${1#-}" != "$1" ]; then if [ "${1#-}" != "$1" ]; then
<<<<<<< HEAD
set -- gerbil "$@" set -- gerbil "$@"
=======
set -- newt "$@"
>>>>>>> env-vars
fi fi
exec "$@" exec "$@"

96
main.go
View File

@@ -21,9 +21,9 @@ import (
) )
var ( var (
interfaceName = "wg0" interfaceName string
listenAddr = ":3003" listenAddr string
mtuInt = 1420 mtuInt int
lastReadings = make(map[string]PeerReading) lastReadings = make(map[string]PeerReading)
mu sync.Mutex mu sync.Mutex
) )
@@ -74,58 +74,86 @@ func parseLogLevel(level string) logger.LogLevel {
} }
func main() { func main() {
var err error var (
var wgconfig WgConfig err error
wgconfig WgConfig
configFile string
remoteConfigURL string
reportBandwidthTo string
generateAndSaveKeyTo string
reachableAt string
logLevel string
mtu string
)
// Define command line flags interfaceName = os.Getenv("INTERFACE")
interfaceNameArg := flag.String("interface", "wg0", "Name of the WireGuard interface") configFile = os.Getenv("CONFIG")
mtu := flag.String("mtu", "1280", "MTU of the interface") remoteConfigURL = os.Getenv("REMOTE_CONFIG")
configFile := flag.String("config", "", "Path to local configuration file") listenAddr = os.Getenv("LISTEN")
remoteConfigURL := flag.String("remoteConfig", "", "URL to fetch remote configuration") reportBandwidthTo = os.Getenv("REPORT_BANDWIDTH_TO")
listenAddrArg := flag.String("listen", ":3003", "Address to listen on") generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
reportBandwidthTo := flag.String("reportBandwidthTo", "", "Address to listen on") reachableAt = os.Getenv("REACHABLE_AT")
generateAndSaveKeyTo := flag.String("generateAndSaveKeyTo", "", "Path to save generated private key") logLevel = os.Getenv("LOG_LEVEL")
reachableAt := flag.String("reachableAt", "", "Endpoint of the http server to tell remote config about") mtu = os.Getenv("MTU")
logLevel := flag.String("log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
if interfaceName == "" {
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
}
if configFile == "" {
flag.StringVar(&configFile, "config", "", "Path to local configuration file")
}
if remoteConfigURL == "" {
flag.StringVar(&remoteConfigURL, "remoteConfig", "", "URL to fetch remote configuration")
}
if listenAddr == "" {
flag.StringVar(&listenAddr, "listen", ":3003", "Address to listen on")
}
if reportBandwidthTo == "" {
flag.StringVar(&reportBandwidthTo, "reportBandwidthTo", "", "Address to listen on")
}
if generateAndSaveKeyTo == "" {
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
}
if reachableAt == "" {
flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about")
}
if logLevel == "" {
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
}
if mtu == "" {
flag.StringVar(&mtu, "mtu", "1420", "MTU of the WireGuard interface")
}
flag.Parse() flag.Parse()
logger.Init() logger.Init()
logger.GetLogger().SetLevel(parseLogLevel(*logLevel)) logger.GetLogger().SetLevel(parseLogLevel(logLevel))
if *interfaceNameArg != "" { mtuInt, err = strconv.Atoi(mtu)
interfaceName = *interfaceNameArg
}
if *listenAddrArg != "" {
listenAddr = *listenAddrArg
}
mtuInt, err = strconv.Atoi(*mtu)
if err != nil { if err != nil {
logger.Fatal("Failed to parse MTU: %v", err) logger.Fatal("Failed to parse MTU: %v", err)
} }
// Validate that only one config option is provided // Validate that only one config option is provided
if (*configFile != "" && *remoteConfigURL != "") || (*configFile == "" && *remoteConfigURL == "") { if (configFile != "" && remoteConfigURL != "") || (configFile == "" && remoteConfigURL == "") {
logger.Fatal("Please provide either --config or --remoteConfig, but not both") logger.Fatal("Please provide either --config or --remoteConfig, but not both")
} }
var key wgtypes.Key var key wgtypes.Key
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file // if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
if *generateAndSaveKeyTo != "" { if generateAndSaveKeyTo != "" {
if _, err := os.Stat(*generateAndSaveKeyTo); os.IsNotExist(err) { if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) {
// generate a new private key // generate a new private key
key, err = wgtypes.GeneratePrivateKey() key, err = wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
logger.Fatal("Failed to generate private key: %v", err) logger.Fatal("Failed to generate private key: %v", err)
} }
// save the key to the file // save the key to the file
err = os.WriteFile(*generateAndSaveKeyTo, []byte(key.String()), 0644) err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644)
if err != nil { if err != nil {
logger.Fatal("Failed to save private key: %v", err) logger.Fatal("Failed to save private key: %v", err)
} }
} else { } else {
keyData, err := os.ReadFile(*generateAndSaveKeyTo) keyData, err := os.ReadFile(generateAndSaveKeyTo)
if err != nil { if err != nil {
logger.Fatal("Failed to read private key: %v", err) logger.Fatal("Failed to read private key: %v", err)
} }
@@ -146,8 +174,8 @@ func main() {
} }
// Load configuration based on provided argument // Load configuration based on provided argument
if *configFile != "" { if configFile != "" {
wgconfig, err = loadConfig(*configFile) wgconfig, err = loadConfig(configFile)
if err != nil { if err != nil {
logger.Fatal("Failed to load configuration: %v", err) logger.Fatal("Failed to load configuration: %v", err)
} }
@@ -155,7 +183,7 @@ func main() {
wgconfig.PrivateKey = key.String() wgconfig.PrivateKey = key.String()
} }
} else { } else {
wgconfig, err = loadRemoteConfig(*remoteConfigURL, key, *reachableAt) wgconfig, err = loadRemoteConfig(remoteConfigURL, key, reachableAt)
if err != nil { if err != nil {
logger.Fatal("Failed to load configuration: %v", err) logger.Fatal("Failed to load configuration: %v", err)
} }
@@ -176,8 +204,8 @@ func main() {
// Ensure the WireGuard peers exist // Ensure the WireGuard peers exist
ensureWireguardPeers(wgconfig.Peers) ensureWireguardPeers(wgconfig.Peers)
if *reportBandwidthTo != "" { if reportBandwidthTo != "" {
go periodicBandwidthCheck(*reportBandwidthTo) go periodicBandwidthCheck(reportBandwidthTo)
} }
http.HandleFunc("/peer", handlePeer) http.HandleFunc("/peer", handlePeer)