diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 791d63364..ae72a50f2 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -139,7 +139,10 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec group, err := s.keyStore.GetGroupByName(ctx, rp.Name, rp.AccountID) if err != nil { - // TODO: log this? + log.WithFields(log.Fields{ + "proxy": rp.Name, + "account": rp.AccountID, + }).WithError(err).Error("Failed to get group by name") continue } @@ -156,7 +159,11 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec false, ) if err != nil { - // TODO: how to handle this? + log.WithFields(log.Fields{ + "proxy": rp.Name, + "account": rp.AccountID, + "group": group.ID, + }).WithError(err).Error("Failed to create setup key") continue } @@ -168,7 +175,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec ), }, }); err != nil { - // TODO: log the error, maybe retry? + log.WithError(err).Error("Failed to send proxy mapping") continue } } diff --git a/proxy/cmd/proxy/main.go b/proxy/cmd/proxy/main.go index bd55ae98d..244b6b8f2 100644 --- a/proxy/cmd/proxy/main.go +++ b/proxy/cmd/proxy/main.go @@ -74,12 +74,14 @@ func main() { if debug { level = "debug" } + logger := log.New() - _ = util.InitLog(level, util.LogConsole) + _ = util.InitLogger(logger, level, util.LogConsole) log.Infof("configured log level: %s", level) srv := proxy.Server{ + Logger: logger, Version: Version, ManagementAddress: mgmtAddr, ProxyURL: url, diff --git a/proxy/internal/accesslog/logger.go b/proxy/internal/accesslog/logger.go index 64bf47cd1..b23f79b58 100644 --- a/proxy/internal/accesslog/logger.go +++ b/proxy/internal/accesslog/logger.go @@ -16,11 +16,16 @@ type gRPCClient interface { type Logger struct { client gRPCClient + logger *log.Logger } -func NewLogger(client gRPCClient) *Logger { +func NewLogger(client gRPCClient, logger *log.Logger) *Logger { + if logger == nil { + logger = log.StandardLogger() + } return &Logger{ client: client, + logger: logger, } } @@ -68,7 +73,7 @@ func (l *Logger) log(ctx context.Context, entry logEntry) { }, }); err != nil { // If it fails to send on the gRPC connection, then at least log it to the error log. - log.WithFields(log.Fields{ + l.logger.WithFields(log.Fields{ "service_id": entry.ServiceId, "host": entry.Host, "path": entry.Path, diff --git a/proxy/internal/auth/oidc.go b/proxy/internal/auth/oidc.go index 9c3787a19..dacc55cb8 100644 --- a/proxy/internal/auth/oidc.go +++ b/proxy/internal/auth/oidc.go @@ -12,7 +12,6 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" - log "github.com/sirupsen/logrus" "golang.org/x/oauth2" ) @@ -165,7 +164,6 @@ func (o *OIDC) handleCallback(w http.ResponseWriter, r *http.Request) { // Exchange code for token token, err := o.oauthConfig.Exchange(r.Context(), code) if err != nil { - log.WithError(err).Error("Token exchange failed") http.Error(w, "Authentication failed", http.StatusUnauthorized) return } diff --git a/proxy/internal/roundtrip/netbird.go b/proxy/internal/roundtrip/netbird.go index e01501dc3..3ce5608c9 100644 --- a/proxy/internal/roundtrip/netbird.go +++ b/proxy/internal/roundtrip/netbird.go @@ -2,6 +2,7 @@ package roundtrip import ( "context" + "errors" "fmt" "io" "net/http" @@ -19,14 +20,19 @@ const deviceNamePrefix = "ingress-" // backed by underlying NetBird connections. type NetBird struct { mgmtAddr string + logger *log.Logger clientsMux sync.RWMutex clients map[string]*embed.Client } -func NewNetBird(mgmtAddr string) *NetBird { +func NewNetBird(mgmtAddr string, logger *log.Logger) *NetBird { + if logger == nil { + logger = log.StandardLogger() + } return &NetBird{ mgmtAddr: mgmtAddr, + logger: logger, clients: make(map[string]*embed.Client), } } @@ -37,13 +43,29 @@ func (n *NetBird) AddPeer(ctx context.Context, domain, key string) error { ManagementURL: n.mgmtAddr, SetupKey: key, LogOutput: io.Discard, + BlockInbound: true, }) if err != nil { return fmt.Errorf("create netbird client: %w", err) } - if err := client.Start(ctx); err != nil { - return fmt.Errorf("start netbird client: %w", err) - } + + // Attempt to start the client in the background, if this fails + // then it is not ideal, but it isn't the end of the world because + // we will try to start the client again before we use it. + go func() { + startCtx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + err = client.Start(startCtx) + switch { + case errors.Is(err, context.DeadlineExceeded): + n.logger.Debug("netbird client timed out") + // This is not ideal, but we will try again later. + return + case err != nil: + n.logger.WithField("domain", domain).WithError(err).Error("Unable to start netbird client, will try again later.") + } + }() + n.clientsMux.Lock() defer n.clientsMux.Unlock() n.clients[domain] = client @@ -78,7 +100,20 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) { return nil, fmt.Errorf("no peer connection found for host: %s", req.Host) } - log.WithFields(log.Fields{ + // Attempt to start the client, if the client is already running then + // it will return an error that we ignore, if this hits a timeout then + // this request is unprocessable. + startCtx, cancel := context.WithTimeout(req.Context(), 3*time.Second) + defer cancel() + err := client.Start(startCtx) + switch { + case errors.Is(err, embed.ErrClientAlreadyStarted): + break + case err != nil: + return nil, fmt.Errorf("start netbird client: %w", err) + } + + n.logger.WithFields(log.Fields{ "host": req.Host, "url": req.URL.String(), "requestURI": req.RequestURI, diff --git a/proxy/server.go b/proxy/server.go index 6009505ed..2a6e71bfb 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -50,6 +50,7 @@ type Server struct { startTime time.Time ID string + Logger *log.Logger Version string ProxyURL string ManagementAddress string @@ -71,6 +72,11 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { s.Version = "dev" } + // If no logger is specified fallback to the standard logger. + if s.Logger == nil { + s.Logger = log.StandardLogger() + } + // The very first thing to do should be to connect to the Management server. // Without this connection, the Proxy cannot do anything. mgmtURL, err := url.Parse(s.ManagementAddress) @@ -91,7 +97,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { RootCAs: certPool, }) } - log.WithFields(log.Fields{ + s.Logger.WithFields(log.Fields{ "gRPC_address": mgmtURL.Host, "TLS_enabled": mgmtURL.Scheme == "https", }).Debug("starting management gRPC client") @@ -111,12 +117,12 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { // Initialize the netbird client, this is required to build peer connections // to proxy over. - s.netbird = roundtrip.NewNetBird(s.ManagementAddress) + s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.Logger) // When generating ACME certificates, start a challenge server. tlsConfig := &tls.Config{} if s.GenerateACMECertificates { - log.WithField("acme_server", s.ACMEDirectory).Debug("ACME certificates enabled, configuring certificate manager") + s.Logger.WithField("acme_server", s.ACMEDirectory).Debug("ACME certificates enabled, configuring certificate manager") s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory) s.http = &http.Server{ Addr: s.ACMEChallengeAddress, @@ -126,7 +132,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { if err := s.http.ListenAndServe(); err != nil { // Rather than retry, log the issue periodically so that hopefully someone notices and fixes the issue. for range time.Tick(10 * time.Second) { - log.WithError(err).Error("ACME HTTP-01 challenge server error") + s.Logger.WithError(err).Error("ACME HTTP-01 challenge server error") } } }() @@ -142,11 +148,11 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { // when using CNAME URLs to access the proxy. tlsConfig.ServerName = s.ProxyURL - log.WithFields(log.Fields{ + s.Logger.WithFields(log.Fields{ "ServerName": s.ProxyURL, }).Debug("started ACME challenge server") } else { - log.Debug("ACME certificates disabled, using static certificates") + s.Logger.Debug("ACME certificates disabled, using static certificates") // Otherwise pull some certificates from expected locations. cert, err := tls.LoadX509KeyPair( filepath.Join(s.CertificateDirectory, "tls.crt"), @@ -165,7 +171,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { s.auth = auth.NewMiddleware() // Configure Access logs to management server. - accessLog := accesslog.NewLogger(mgmtClient) + accessLog := accesslog.NewLogger(mgmtClient, s.Logger) // Finally, start the reverse proxy. s.https = &http.Server{ @@ -179,23 +185,23 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.ProxyServiceClient) { b := backoff.New(0, 0) for { - log.Debug("Getting mapping updates from management server") + s.Logger.Debug("Getting mapping updates from management server") mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ ProxyId: s.ID, Version: s.Version, StartedAt: timestamppb.New(s.startTime), }) if err != nil { - log.WithError(err).Warn("Could not get mapping updates, will retry") + s.Logger.WithError(err).Warn("Could not get mapping updates, will retry") backoffDuration := b.Duration() - log.WithFields(log.Fields{ + s.Logger.WithFields(log.Fields{ "backoff": backoffDuration, "error": err, }).Error("Unable to create mapping client to management server, retrying connection after backoff") time.Sleep(backoffDuration) continue } - log.Debug("Got mapping updates client from management server") + s.Logger.Debug("Got mapping updates client from management server") err = s.handleMappingStream(ctx, mappingClient) backoffDuration := b.Duration() switch { @@ -203,17 +209,17 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr errors.Is(err, context.DeadlineExceeded): // Context is telling us that it is time to quit so gracefully exit here. // No need to log the error as it is a parent context causing this return. - log.Debugf("Got context error, will exit loop: %v", err) + s.Logger.Debugf("Got context error, will exit loop: %v", err) return case err != nil: // Log the error and then retry the connection. - log.WithFields(log.Fields{ + s.Logger.WithFields(log.Fields{ "backoff": backoffDuration, "error": err, }).Error("Error processing mapping stream from management server, retrying connection after backoff") default: // TODO: should this really be at error level? Maybe, if you start getting lots of these this could be an indication of connectivity issues. - log.WithField("backoff", backoffDuration).Error("Management mapping connection terminated by the server, retrying connection after backoff") + s.Logger.WithField("backoff", backoffDuration).Error("Management mapping connection terminated by the server, retrying connection after backoff") } time.Sleep(backoffDuration) } @@ -236,11 +242,11 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr // Something has gone horribly wrong, return and hope the parent retries the connection. return fmt.Errorf("receive msg: %w", err) } - log.Debug("Received mapping update, starting processing") + s.Logger.Debug("Received mapping update, starting processing") // Process msg updates sequentially to avoid conflict, so block // additional receiving until this processing is completed. for _, mapping := range msg.GetMapping() { - log.WithFields(log.Fields{ + s.Logger.WithFields(log.Fields{ "type": mapping.GetType(), "domain": mapping.GetDomain(), "path": mapping.GetPath(), @@ -250,7 +256,7 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED: if err := s.addMapping(ctx, mapping); err != nil { // TODO: Retry this? Or maybe notify the management server that this mapping has failed? - log.WithFields(log.Fields{ + s.Logger.WithFields(log.Fields{ "service_id": mapping.GetId(), "domain": mapping.GetDomain(), "error": err, @@ -262,6 +268,7 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr s.removeMapping(ctx, mapping) } } + s.Logger.Debug("Processing mapping update completed") } } } @@ -305,7 +312,7 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping) OIDCScopes: oidc.GetOidcScopes(), }) if err != nil { - log.WithError(err).Error("Failed to create OIDC scheme") + s.Logger.WithError(err).Error("Failed to create OIDC scheme") } else { schemes = append(schemes, scheme) } @@ -319,7 +326,7 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping) func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) { if err := s.netbird.RemovePeer(ctx, mapping.GetDomain()); err != nil { - log.WithFields(log.Fields{ + s.Logger.WithFields(log.Fields{ "domain": mapping.GetDomain(), "error": err, }).Error("Error removing NetBird peer connection for domain, continuing additional domain cleanup but peer connection may still exist") @@ -337,7 +344,7 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping { targetURL, err := url.Parse(pathMapping.GetTarget()) if err != nil { // TODO: Should we warn management about this so it can be bubbled up to a user to reconfigure? - log.WithFields(log.Fields{ + s.Logger.WithFields(log.Fields{ "service_id": mapping.GetId(), "account_id": mapping.GetAccountId(), "domain": mapping.GetDomain(), diff --git a/util/log.go b/util/log.go index a951eab87..03547024a 100644 --- a/util/log.go +++ b/util/log.go @@ -30,9 +30,14 @@ var ( // InitLog parses and sets log-level input func InitLog(logLevel string, logs ...string) error { + return InitLogger(log.StandardLogger(), logLevel, logs...) +} + +// InitLogger parses and sets log-level input for a logrus logger +func InitLogger(logger *log.Logger, logLevel string, logs ...string) error { level, err := log.ParseLevel(logLevel) if err != nil { - log.Errorf("Failed parsing log-level %s: %s", logLevel, err) + logger.Errorf("Failed parsing log-level %s: %s", logLevel, err) return err } var writers []io.Writer @@ -41,34 +46,34 @@ func InitLog(logLevel string, logs ...string) error { for _, logPath := range logs { switch logPath { case LogSyslog: - AddSyslogHook() + AddSyslogHookToLogger(logger) logFmt = "syslog" case LogConsole: writers = append(writers, os.Stderr) case "": - log.Warnf("empty log path received: %#v", logPath) + logger.Warnf("empty log path received: %#v", logPath) default: writers = append(writers, newRotatedOutput(logPath)) } } if len(writers) > 1 { - log.SetOutput(io.MultiWriter(writers...)) + logger.SetOutput(io.MultiWriter(writers...)) } else if len(writers) == 1 { - log.SetOutput(writers[0]) + logger.SetOutput(writers[0]) } switch logFmt { case "json": - formatter.SetJSONFormatter(log.StandardLogger()) + formatter.SetJSONFormatter(logger) case "syslog": - formatter.SetSyslogFormatter(log.StandardLogger()) + formatter.SetSyslogFormatter(logger) default: - formatter.SetTextFormatter(log.StandardLogger()) + formatter.SetTextFormatter(logger) } - log.SetLevel(level) + logger.SetLevel(level) - setGRPCLibLogger() + setGRPCLibLogger(logger) return nil } @@ -96,8 +101,8 @@ func newRotatedOutput(logPath string) io.Writer { return lumberjackLogger } -func setGRPCLibLogger() { - logOut := log.StandardLogger().Writer() +func setGRPCLibLogger(logger *log.Logger) { + logOut := logger.Writer() if os.Getenv("GRPC_GO_LOG_SEVERITY_LEVEL") != "info" { grpclog.SetLoggerV2(grpclog.NewLoggerV2(io.Discard, logOut, logOut)) return diff --git a/util/syslog_nonwindows.go b/util/syslog_nonwindows.go index 328bb8b1c..4a33f21b1 100644 --- a/util/syslog_nonwindows.go +++ b/util/syslog_nonwindows.go @@ -10,10 +10,14 @@ import ( ) func AddSyslogHook() { + AddSyslogHookToLogger(log.StandardLogger()) +} + +func AddSyslogHookToLogger(logger *log.Logger) { hook, err := lSyslog.NewSyslogHook("", "", syslog.LOG_INFO, "") if err != nil { - log.Errorf("Failed creating syslog hook: %s", err) + logger.Errorf("Failed creating syslog hook: %s", err) } - log.AddHook(hook) + logger.AddHook(hook) } diff --git a/util/syslog_windows.go b/util/syslog_windows.go index 171c1a459..a4120ef66 100644 --- a/util/syslog_windows.go +++ b/util/syslog_windows.go @@ -4,3 +4,8 @@ func AddSyslogHook() { // The syslog package is not available for Windows. This adapter is needed // to handle windows build. } + +func AddSyslogHookToLogger(logger *log.Logger) { + // The syslog package is not available for Windows. This adapter is needed + // to handle windows build. +}