diff --git a/go.mod b/go.mod index 26bbe23..36ffe8f 100644 --- a/go.mod +++ b/go.mod @@ -2,18 +2,21 @@ module github.com/fosrl/gerbil go 1.21.5 +require ( + github.com/vishvananda/netlink v1.3.0 + golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 +) + require ( github.com/google/go-cmp v0.5.9 // indirect github.com/josharian/native v1.1.0 // indirect github.com/mdlayher/genetlink v1.3.2 // indirect github.com/mdlayher/netlink v1.7.2 // indirect github.com/mdlayher/socket v0.4.1 // indirect - github.com/vishvananda/netlink v1.3.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect golang.org/x/crypto v0.8.0 // indirect golang.org/x/net v0.9.0 // indirect golang.org/x/sync v0.1.0 // indirect golang.org/x/sys v0.10.0 // indirect golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b // indirect - golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 // indirect ) diff --git a/logger/level.go b/logger/level.go new file mode 100644 index 0000000..175995f --- /dev/null +++ b/logger/level.go @@ -0,0 +1,27 @@ +package logger + +type LogLevel int + +const ( + DEBUG LogLevel = iota + INFO + WARN + ERROR + FATAL +) + +var levelStrings = map[LogLevel]string{ + DEBUG: "DEBUG", + INFO: "INFO", + WARN: "WARN", + ERROR: "ERROR", + FATAL: "FATAL", +} + +// String returns the string representation of the log level +func (l LogLevel) String() string { + if s, ok := levelStrings[l]; ok { + return s + } + return "UNKNOWN" +} diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 0000000..9ef486d --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,106 @@ +package logger + +import ( + "fmt" + "log" + "os" + "sync" + "time" +) + +// Logger struct holds the logger instance +type Logger struct { + logger *log.Logger + level LogLevel +} + +var ( + defaultLogger *Logger + once sync.Once +) + +// NewLogger creates a new logger instance +func NewLogger() *Logger { + return &Logger{ + logger: log.New(os.Stdout, "", 0), + level: DEBUG, + } +} + +// Init initializes the default logger +func Init() *Logger { + once.Do(func() { + defaultLogger = NewLogger() + }) + return defaultLogger +} + +// GetLogger returns the default logger instance +func GetLogger() *Logger { + if defaultLogger == nil { + Init() + } + return defaultLogger +} + +// SetLevel sets the minimum logging level +func (l *Logger) SetLevel(level LogLevel) { + l.level = level +} + +// log handles the actual logging +func (l *Logger) log(level LogLevel, format string, args ...interface{}) { + if level < l.level { + return + } + timestamp := time.Now().Format("2006/01/02 15:04:05") + message := fmt.Sprintf(format, args...) + l.logger.Printf("%s: %s %s", level.String(), timestamp, message) +} + +// Debug logs debug level messages +func (l *Logger) Debug(format string, args ...interface{}) { + l.log(DEBUG, format, args...) +} + +// Info logs info level messages +func (l *Logger) Info(format string, args ...interface{}) { + l.log(INFO, format, args...) +} + +// Warn logs warning level messages +func (l *Logger) Warn(format string, args ...interface{}) { + l.log(WARN, format, args...) +} + +// Error logs error level messages +func (l *Logger) Error(format string, args ...interface{}) { + l.log(ERROR, format, args...) +} + +// Fatal logs fatal level messages and exits +func (l *Logger) Fatal(format string, args ...interface{}) { + l.log(FATAL, format, args...) + os.Exit(1) +} + +// Global helper functions +func Debug(format string, args ...interface{}) { + GetLogger().Debug(format, args...) +} + +func Info(format string, args ...interface{}) { + GetLogger().Info(format, args...) +} + +func Warn(format string, args ...interface{}) { + GetLogger().Warn(format, args...) +} + +func Error(format string, args ...interface{}) { + GetLogger().Error(format, args...) +} + +func Fatal(format string, args ...interface{}) { + GetLogger().Fatal(format, args...) +} diff --git a/main.go b/main.go index 115357a..64a4946 100644 --- a/main.go +++ b/main.go @@ -6,13 +6,14 @@ import ( "flag" "fmt" "io" - "log" "net" "net/http" "os" + "strings" "sync" "time" + "github.com/fosrl/gerbil/logger" "github.com/vishvananda/netlink" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -53,6 +54,23 @@ var ( wgClient *wgctrl.Client ) +func parseLogLevel(level string) logger.LogLevel { + switch strings.ToUpper(level) { + case "DEBUG": + return logger.DEBUG + case "INFO": + return logger.INFO + case "WARN": + return logger.WARN + case "ERROR": + return logger.ERROR + case "FATAL": + return logger.FATAL + default: + return logger.INFO // default to INFO if invalid level provided + } +} + func main() { var err error var wgconfig WgConfig @@ -65,8 +83,13 @@ func main() { reportBandwidthTo := flag.String("reportBandwidthTo", "", "Address to listen on") generateAndSaveKeyTo := flag.String("generateAndSaveKeyTo", "", "Path to save generated private key") reachableAt := flag.String("reachableAt", "", "Endpoint of the http server to tell remote config about") + logLevel := flag.String("log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") + flag.Parse() + logger.Init() + logger.GetLogger().SetLevel(parseLogLevel(*logLevel)) + if *interfaceNameArg != "" { interfaceName = *interfaceNameArg } @@ -76,7 +99,7 @@ func main() { // Validate that only one config option is provided if (*configFile != "" && *remoteConfigURL != "") || (*configFile == "" && *remoteConfigURL == "") { - log.Fatal("Please provide either --config or --remoteConfig, but not both") + logger.Fatal("Please provide either --config or --remoteConfig, but not both") } var key wgtypes.Key @@ -86,21 +109,21 @@ func main() { // generate a new private key key, err = wgtypes.GeneratePrivateKey() if err != nil { - log.Fatalf("Failed to generate private key: %v", err) + logger.Fatal("Failed to generate private key: %v", err) } // save the key to the file err = os.WriteFile(*generateAndSaveKeyTo, []byte(key.String()), 0644) if err != nil { - log.Fatalf("Failed to save private key: %v", err) + logger.Fatal("Failed to save private key: %v", err) } } else { keyData, err := os.ReadFile(*generateAndSaveKeyTo) if err != nil { - log.Fatalf("Failed to read private key: %v", err) + logger.Fatal("Failed to read private key: %v", err) } key, err = wgtypes.ParseKey(string(keyData)) if err != nil { - log.Fatalf("Failed to parse private key: %v", err) + logger.Fatal("Failed to parse private key: %v", err) } } } else { @@ -109,7 +132,7 @@ func main() { // generate a new one key, err = wgtypes.GeneratePrivateKey() if err != nil { - log.Fatalf("Failed to generate private key: %v", err) + logger.Fatal("Failed to generate private key: %v", err) } } } @@ -118,7 +141,7 @@ func main() { if *configFile != "" { wgconfig, err = loadConfig(*configFile) if err != nil { - log.Fatalf("Failed to load configuration: %v", err) + logger.Fatal("Failed to load configuration: %v", err) } if wgconfig.PrivateKey == "" { wgconfig.PrivateKey = key.String() @@ -126,20 +149,20 @@ func main() { } else { wgconfig, err = loadRemoteConfig(*remoteConfigURL, key, *reachableAt) if err != nil { - log.Fatalf("Failed to load configuration: %v", err) + logger.Fatal("Failed to load configuration: %v", err) } wgconfig.PrivateKey = key.String() } wgClient, err = wgctrl.New() if err != nil { - log.Fatalf("Failed to create WireGuard client: %v", err) + logger.Fatal("Failed to create WireGuard client: %v", err) } defer wgClient.Close() // Ensure the WireGuard interface exists and is configured if err := ensureWireguardInterface(wgconfig); err != nil { - log.Fatalf("Failed to ensure WireGuard interface: %v", err) + logger.Fatal("Failed to ensure WireGuard interface: %v", err) } // Ensure the WireGuard peers exist @@ -150,8 +173,8 @@ func main() { } http.HandleFunc("/peer", handlePeer) - log.Printf("Starting server on %s", listenAddr) - log.Fatal(http.ListenAndServe(listenAddr, nil)) + logger.Info("Starting server on %s", listenAddr) + logger.Fatal("Failed to start server: %v", http.ListenAndServe(listenAddr, nil)) } func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) { @@ -164,7 +187,7 @@ func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig resp, err := http.Post(url, "application/json", body) if err != nil { // print the error - fmt.Println("Error fetching remote config:", err) + logger.Error("Error fetching remote config %s: %v", url, err) return WgConfig{}, err } defer resp.Body.Close() @@ -184,7 +207,7 @@ func loadConfig(filename string) (WgConfig, error) { // Open the JSON file file, err := os.Open(filename) if err != nil { - fmt.Println("Error opening file:", err) + logger.Error("Error opening file %s: %v", filename, err) return WgConfig{}, err } defer file.Close() @@ -192,7 +215,7 @@ func loadConfig(filename string) (WgConfig, error) { // Read the file contents byteValue, err := io.ReadAll(file) if err != nil { - fmt.Println("Error reading file:", err) + logger.Error("Error reading file %s: %v", filename, err) return WgConfig{}, err } @@ -202,7 +225,7 @@ func loadConfig(filename string) (WgConfig, error) { // Unmarshal the JSON data into the struct err = json.Unmarshal(byteValue, &wgconfig) if err != nil { - fmt.Println("Error unmarshaling JSON:", err) + logger.Error("Error unmarshaling JSON data: %v", err) return WgConfig{}, err } @@ -217,23 +240,23 @@ func ensureWireguardInterface(wgconfig WgConfig) error { // Interface doesn't exist, so create it err = createWireGuardInterface() if err != nil { - log.Fatalf("Failed to create WireGuard interface: %v", err) + logger.Fatal("Failed to create WireGuard interface: %v", err) } - log.Printf("Created WireGuard interface %s\n", interfaceName) + logger.Info("Created WireGuard interface %s\n", interfaceName) } else { - log.Fatalf("Error checking for WireGuard interface: %v", err) + logger.Fatal("Error checking for WireGuard interface: %v", err) } } else { - log.Printf("WireGuard interface %s already exists\n", interfaceName) + logger.Info("WireGuard interface %s already exists\n", interfaceName) return nil } // Assign IP address to the interface err = assignIPAddress(wgconfig.IpAddress) if err != nil { - log.Fatalf("Failed to assign IP address: %v", err) + logger.Fatal("Failed to assign IP address: %v", err) } - log.Printf("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, interfaceName) + logger.Info("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, interfaceName) // Check if the interface already exists _, err = wgClient.Device(interfaceName) @@ -269,7 +292,7 @@ func ensureWireguardInterface(wgconfig WgConfig) error { return fmt.Errorf("failed to bring up interface: %v", err) } - log.Printf("WireGuard interface %s created and configured", interfaceName) + logger.Info("WireGuard interface %s created and configured", interfaceName) return nil } @@ -403,7 +426,7 @@ func addPeer(peer Peer) error { return fmt.Errorf("failed to add peer: %v", err) } - log.Printf("Peer %s added successfully", peer.PublicKey) + logger.Info("Peer %s added successfully", peer.PublicKey) return nil } @@ -444,7 +467,7 @@ func removePeer(publicKey string) error { return fmt.Errorf("failed to remove peer: %v", err) } - log.Printf("Peer %s removed successfully", publicKey) + logger.Info("Peer %s removed successfully", publicKey) return nil } @@ -455,7 +478,7 @@ func periodicBandwidthCheck(endpoint string) { for range ticker.C { if err := reportPeerBandwidth(endpoint); err != nil { - log.Printf("Failed to report peer bandwidth: %v", err) + logger.Info("Failed to report peer bandwidth: %v", err) } } } @@ -546,6 +569,5 @@ func reportPeerBandwidth(apiURL string) error { return fmt.Errorf("API returned non-OK status: %s", resp.Status) } - // log.Println("Bandwidth data sent successfully") return nil }