diff --git a/common.go b/common.go index 5d9ed7b..192d93c 100644 --- a/common.go +++ b/common.go @@ -13,8 +13,8 @@ import ( "strings" "time" - "github.com/fosrl/newt/logger" "github.com/fosrl/newt/websocket" + "github.com/fosrl/olm/logger" "github.com/fosrl/olm/peermonitor" "github.com/vishvananda/netlink" "golang.org/x/crypto/chacha20poly1305" diff --git a/logger/level.go b/logger/level.go new file mode 100644 index 0000000..175995f --- /dev/null +++ b/logger/level.go @@ -0,0 +1,27 @@ +package logger + +type LogLevel int + +const ( + DEBUG LogLevel = iota + INFO + WARN + ERROR + FATAL +) + +var levelStrings = map[LogLevel]string{ + DEBUG: "DEBUG", + INFO: "INFO", + WARN: "WARN", + ERROR: "ERROR", + FATAL: "FATAL", +} + +// String returns the string representation of the log level +func (l LogLevel) String() string { + if s, ok := levelStrings[l]; ok { + return s + } + return "UNKNOWN" +} diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 0000000..28cac91 --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,133 @@ +package logger + +import ( + "fmt" + "io" + "log" + "os" + "sync" + "time" +) + +// Logger struct holds the logger instance +type Logger struct { + logger *log.Logger + level LogLevel +} + +var ( + defaultLogger *Logger + once sync.Once +) + +// NewLogger creates a new logger instance +func NewLogger() *Logger { + return &Logger{ + logger: log.New(os.Stdout, "", 0), + level: DEBUG, + } +} + +// Init initializes the default logger +func Init() *Logger { + once.Do(func() { + defaultLogger = NewLogger() + }) + return defaultLogger +} + +// GetLogger returns the default logger instance +func GetLogger() *Logger { + if defaultLogger == nil { + Init() + } + return defaultLogger +} + +// SetLevel sets the minimum logging level +func (l *Logger) SetLevel(level LogLevel) { + l.level = level +} + +// SetOutput sets the output destination for the logger +func (l *Logger) SetOutput(w io.Writer) { + l.logger.SetOutput(w) +} + +// log handles the actual logging +func (l *Logger) log(level LogLevel, format string, args ...interface{}) { + if level < l.level { + return + } + + // Get timezone from environment variable or use local timezone + timezone := os.Getenv("LOGGER_TIMEZONE") + var location *time.Location + var err error + + if timezone != "" { + location, err = time.LoadLocation(timezone) + if err != nil { + // If invalid timezone, fall back to local + location = time.Local + } + } else { + location = time.Local + } + + timestamp := time.Now().In(location).Format("2006/01/02 15:04:05") + message := fmt.Sprintf(format, args...) + l.logger.Printf("%s: %s %s", level.String(), timestamp, message) +} + +// Debug logs debug level messages +func (l *Logger) Debug(format string, args ...interface{}) { + l.log(DEBUG, format, args...) +} + +// Info logs info level messages +func (l *Logger) Info(format string, args ...interface{}) { + l.log(INFO, format, args...) +} + +// Warn logs warning level messages +func (l *Logger) Warn(format string, args ...interface{}) { + l.log(WARN, format, args...) +} + +// Error logs error level messages +func (l *Logger) Error(format string, args ...interface{}) { + l.log(ERROR, format, args...) +} + +// Fatal logs fatal level messages and exits +func (l *Logger) Fatal(format string, args ...interface{}) { + l.log(FATAL, format, args...) + os.Exit(1) +} + +// Global helper functions +func Debug(format string, args ...interface{}) { + GetLogger().Debug(format, args...) +} + +func Info(format string, args ...interface{}) { + GetLogger().Info(format, args...) +} + +func Warn(format string, args ...interface{}) { + GetLogger().Warn(format, args...) +} + +func Error(format string, args ...interface{}) { + GetLogger().Error(format, args...) +} + +func Fatal(format string, args ...interface{}) { + GetLogger().Fatal(format, args...) +} + +// SetOutput sets the output destination for the default logger +func SetOutput(w io.Writer) { + GetLogger().SetOutput(w) +} diff --git a/main.go b/main.go index bc25004..45e1303 100644 --- a/main.go +++ b/main.go @@ -13,9 +13,9 @@ import ( "syscall" "time" - "github.com/fosrl/newt/logger" "github.com/fosrl/newt/websocket" "github.com/fosrl/olm/httpserver" + "github.com/fosrl/olm/logger" "github.com/fosrl/olm/peermonitor" "github.com/fosrl/olm/wgtester" @@ -28,7 +28,7 @@ import ( func main() { // Check if we're running as a Windows service if isWindowsService() { - runService("OlmWireguardService", false) + runService("OlmWireguardService", false, os.Args[1:]) fmt.Println("Running as Windows service") return } @@ -77,7 +77,11 @@ func main() { fmt.Printf("Service status: %s\n", status) return case "debug": - runService("OlmWireguardService", true) + err := debugService() + if err != nil { + fmt.Printf("Failed to debug service: %v\n", err) + os.Exit(1) + } return case "help", "--help", "-h": fmt.Println("Olm WireGuard VPN Client") @@ -102,13 +106,15 @@ func main() { } func runOlmMain(ctx context.Context) { - // Log that we've entered the main function - fmt.Printf("runOlmMain() called - starting main logic\n") + runOlmMainWithArgs(ctx, os.Args[1:]) +} - // Setup Windows event logging if on Windows - if runtime.GOOS == "windows" { - setupWindowsEventLog() - } +func runOlmMainWithArgs(ctx context.Context, args []string) { + // Log that we've entered the main function + fmt.Printf("runOlmMainWithArgs() called with args: %v\n", args) + + // Create a new FlagSet for parsing service arguments + serviceFlags := flag.NewFlagSet("service", flag.ContinueOnError) var ( endpoint string @@ -146,39 +152,63 @@ func runOlmMain(ctx context.Context) { pingTimeoutStr := os.Getenv("PING_TIMEOUT") // Debug: Print all environment variables we're checking - fmt.Printf("Environment variables: PANGOLIN_ENDPOINT='%s', OLM_ID='%s', OLM_SECRET='%s'\n", endpoint, id, secret) + // fmt.Printf("Environment variables: PANGOLIN_ENDPOINT='%s', OLM_ID='%s', OLM_SECRET='%s'\n", endpoint, id, secret) + // Setup flags for service mode + // serviceFlags.StringVar(&endpoint, "endpoint", endpoint, "Endpoint of your Pangolin server") + // serviceFlags.StringVar(&id, "id", id, "Olm ID") + // serviceFlags.StringVar(&secret, "secret", secret, "Olm secret") + // serviceFlags.StringVar(&mtu, "mtu", "1280", "MTU to use") + // serviceFlags.StringVar(&dns, "dns", "8.8.8.8", "DNS server to use") + // serviceFlags.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") + // serviceFlags.StringVar(&interfaceName, "interface", "olm", "Name of the WireGuard interface") + // serviceFlags.StringVar(&httpAddr, "http-addr", ":9452", "HTTP server address (e.g., ':9452')") + // serviceFlags.StringVar(&pingIntervalStr, "ping-interval", "3s", "Interval for pinging the server (default 3s)") + // serviceFlags.StringVar(&pingTimeoutStr, "ping-timeout", "5s", "Timeout for each ping (default 5s)") + // serviceFlags.BoolVar(&enableHTTP, "http", false, "Enable HTTP server") + // serviceFlags.BoolVar(&testMode, "test", false, "Test WireGuard connectivity to a target") + // serviceFlags.StringVar(&testTarget, "test-target", "", "Target server:port for test mode") if endpoint == "" { - flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your Pangolin server") + serviceFlags.StringVar(&endpoint, "endpoint", "", "Endpoint of your Pangolin server") } if id == "" { - flag.StringVar(&id, "id", "", "Olm ID") + serviceFlags.StringVar(&id, "id", "", "Olm ID") } if secret == "" { - flag.StringVar(&secret, "secret", "", "Olm secret") + serviceFlags.StringVar(&secret, "secret", "", "Olm secret") } if mtu == "" { - flag.StringVar(&mtu, "mtu", "1280", "MTU to use") + serviceFlags.StringVar(&mtu, "mtu", "1280", "MTU to use") } if dns == "" { - flag.StringVar(&dns, "dns", "8.8.8.8", "DNS server to use") + serviceFlags.StringVar(&dns, "dns", "8.8.8.8", "DNS server to use") } if logLevel == "" { - flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") + serviceFlags.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") } if interfaceName == "" { - flag.StringVar(&interfaceName, "interface", "olm", "Name of the WireGuard interface") + serviceFlags.StringVar(&interfaceName, "interface", "olm", "Name of the WireGuard interface") } if httpAddr == "" { - flag.StringVar(&httpAddr, "http-addr", ":9452", "HTTP server address (e.g., ':9452')") + serviceFlags.StringVar(&httpAddr, "http-addr", ":9452", "HTTP server address (e.g., ':9452')") } if pingIntervalStr == "" { - flag.StringVar(&pingIntervalStr, "ping-interval", "3s", "Interval for pinging the server (default 3s)") + serviceFlags.StringVar(&pingIntervalStr, "ping-interval", "3s", "Interval for pinging the server (default 3s)") } if pingTimeoutStr == "" { - flag.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 3s)") + serviceFlags.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 3s)") } + // Parse the service arguments + if err := serviceFlags.Parse(args); err != nil { + fmt.Printf("Error parsing service arguments: %v\n", err) + return + } + + // Debug: Print final values after flag parsing + fmt.Printf("After flag parsing: endpoint='%s', id='%s', secret='%s'\n", endpoint, id, secret) + + // Parse ping intervals if pingIntervalStr != "" { pingInterval, err = time.ParseDuration(pingIntervalStr) if err != nil { @@ -199,24 +229,13 @@ func runOlmMain(ctx context.Context) { pingTimeout = 5 * time.Second } - flag.BoolVar(&enableHTTP, "http", false, "Enable HTTP server") - flag.BoolVar(&testMode, "test", false, "Test WireGuard connectivity to a target") - flag.StringVar(&testTarget, "test-target", "", "Target server:port for test mode") - - // do a --version check - version := flag.Bool("version", false, "Print the version") - - flag.Parse() - - // Debug: Print final values after flag parsing - fmt.Printf("After flag parsing: endpoint='%s', id='%s', secret='%s'\n", endpoint, id, secret) - - if *version { - fmt.Println("Olm version replaceme") - os.Exit(0) + // Setup Windows event logging if on Windows + if runtime.GOOS == "windows" { + setupWindowsEventLog() + } else { + // Initialize logger for non-Windows platforms + logger.Init() } - - logger.Init() loggerLevel := parseLogLevel(logLevel) logger.GetLogger().SetLevel(parseLogLevel(logLevel)) diff --git a/service_unix.go b/service_unix.go index beeaef1..c616f78 100644 --- a/service_unix.go +++ b/service_unix.go @@ -27,11 +27,15 @@ func getServiceStatus() (string, error) { return "", fmt.Errorf("service management is only available on Windows") } +func debugService() error { + return fmt.Errorf("debug service is only available on Windows") +} + func isWindowsService() bool { return false } -func runService(name string, isDebug bool) { +func runService(name string, isDebug bool, args []string) { // No-op on non-Windows platforms } diff --git a/service_windows.go b/service_windows.go index 0cbc7bd..a0cffc7 100644 --- a/service_windows.go +++ b/service_windows.go @@ -5,11 +5,15 @@ package main import ( "context" "fmt" + "io" "log" "os" + "os/signal" "path/filepath" + "syscall" "time" + "github.com/fosrl/olm/logger" "golang.org/x/sys/windows/svc" "golang.org/x/sys/windows/svc/debug" "golang.org/x/sys/windows/svc/eventlog" @@ -22,10 +26,14 @@ const ( serviceDescription = "Olm WireGuard VPN client service for secure network connectivity" ) +// Global variable to store service arguments +var serviceArgs []string + type olmService struct { elog debug.Log ctx context.Context stop context.CancelFunc + args []string } func (s *olmService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (bool, uint32) { @@ -80,7 +88,6 @@ func (s *olmService) runOlm() { s.ctx, s.stop = context.WithCancel(context.Background()) // Setup logging for service mode - setupWindowsEventLog() s.elog.Info(1, "Starting Olm main logic") // Run the main olm logic and wait for it to complete @@ -93,8 +100,8 @@ func (s *olmService) runOlm() { close(done) }() - // Call the main olm function - runOlmMain(s.ctx) + // Call the main olm function with stored arguments + runOlmMainWithArgs(s.ctx, s.args) }() // Wait for either context cancellation or main logic completion @@ -106,7 +113,7 @@ func (s *olmService) runOlm() { } } -func runService(name string, isDebug bool) { +func runService(name string, isDebug bool, args []string) { var err error var elog debug.Log @@ -128,7 +135,7 @@ func runService(name string, isDebug bool) { run = debug.Run } - service := &olmService{elog: elog} + service := &olmService{elog: elog, args: args} err = run(name, service) if err != nil { elog.Error(1, fmt.Sprintf("%s service failed: %v", name, err)) @@ -291,6 +298,76 @@ func stopService() error { return nil } +func debugService() error { + // Get the log file path + logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "Olm", "logs") + logFile := filepath.Join(logDir, "olm.log") + + fmt.Printf("Starting service in debug mode...\n") + fmt.Printf("Log file: %s\n", logFile) + + // Start the service + err := startService() + if err != nil { + return fmt.Errorf("failed to start service: %v", err) + } + + fmt.Printf("Service started. Watching logs (Press Ctrl+C to stop watching)...\n") + fmt.Printf("================================================================================\n") + + // Watch the log file + return watchLogFile(logFile) +} + +func watchLogFile(logPath string) error { + // Open the log file + file, err := os.Open(logPath) + if err != nil { + return fmt.Errorf("failed to open log file: %v", err) + } + defer file.Close() + + // Seek to the end of the file to only show new logs + _, err = file.Seek(0, 2) + if err != nil { + return fmt.Errorf("failed to seek to end of file: %v", err) + } + + // Set up signal handling for graceful exit + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + + // Create a ticker to check for new content + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + buffer := make([]byte, 4096) + + for { + select { + case <-sigCh: + fmt.Printf("\n\nStopping log watch...\n") + // stop the service if needed + if err := stopService(); err != nil { + fmt.Printf("Failed to stop service: %v\n", err) + } + fmt.Printf("Log watch stopped.\n") + return nil + case <-ticker.C: + // Read new content + n, err := file.Read(buffer) + if err != nil && err != io.EOF { + return fmt.Errorf("error reading log file: %v", err) + } + + if n > 0 { + // Print the new content + fmt.Print(string(buffer[:n])) + } + } + } +} + func getServiceStatus() (string, error) { m, err := mgr.Connect() if err != nil { @@ -349,6 +426,9 @@ func setupWindowsEventLog() { fmt.Printf("Failed to open log file: %v\n", err) return } - log.SetOutput(file) + + // Set the custom logger output + logger.GetLogger().SetOutput(file) + log.Printf("Olm service logging initialized - log file: %s", logFile) }