use a defined logger

this should avoid issues with the embedded
client also attempting to use the same global logger
This commit is contained in:
Alisdair MacLeod
2026-01-30 16:31:32 +00:00
parent 5345d716ee
commit 3a6f364b03
9 changed files with 115 additions and 47 deletions

View File

@@ -139,7 +139,10 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
group, err := s.keyStore.GetGroupByName(ctx, rp.Name, rp.AccountID) group, err := s.keyStore.GetGroupByName(ctx, rp.Name, rp.AccountID)
if err != nil { if err != nil {
// TODO: log this? log.WithFields(log.Fields{
"proxy": rp.Name,
"account": rp.AccountID,
}).WithError(err).Error("Failed to get group by name")
continue continue
} }
@@ -156,7 +159,11 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
false, false,
) )
if err != nil { if err != nil {
// TODO: how to handle this? log.WithFields(log.Fields{
"proxy": rp.Name,
"account": rp.AccountID,
"group": group.ID,
}).WithError(err).Error("Failed to create setup key")
continue continue
} }
@@ -168,7 +175,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
), ),
}, },
}); err != nil { }); err != nil {
// TODO: log the error, maybe retry? log.WithError(err).Error("Failed to send proxy mapping")
continue continue
} }
} }

View File

@@ -74,12 +74,14 @@ func main() {
if debug { if debug {
level = "debug" level = "debug"
} }
logger := log.New()
_ = util.InitLog(level, util.LogConsole) _ = util.InitLogger(logger, level, util.LogConsole)
log.Infof("configured log level: %s", level) log.Infof("configured log level: %s", level)
srv := proxy.Server{ srv := proxy.Server{
Logger: logger,
Version: Version, Version: Version,
ManagementAddress: mgmtAddr, ManagementAddress: mgmtAddr,
ProxyURL: url, ProxyURL: url,

View File

@@ -16,11 +16,16 @@ type gRPCClient interface {
type Logger struct { type Logger struct {
client gRPCClient client gRPCClient
logger *log.Logger
} }
func NewLogger(client gRPCClient) *Logger { func NewLogger(client gRPCClient, logger *log.Logger) *Logger {
if logger == nil {
logger = log.StandardLogger()
}
return &Logger{ return &Logger{
client: client, client: client,
logger: logger,
} }
} }
@@ -68,7 +73,7 @@ func (l *Logger) log(ctx context.Context, entry logEntry) {
}, },
}); err != nil { }); err != nil {
// If it fails to send on the gRPC connection, then at least log it to the error log. // If it fails to send on the gRPC connection, then at least log it to the error log.
log.WithFields(log.Fields{ l.logger.WithFields(log.Fields{
"service_id": entry.ServiceId, "service_id": entry.ServiceId,
"host": entry.Host, "host": entry.Host,
"path": entry.Path, "path": entry.Path,

View File

@@ -12,7 +12,6 @@ import (
"time" "time"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
@@ -165,7 +164,6 @@ func (o *OIDC) handleCallback(w http.ResponseWriter, r *http.Request) {
// Exchange code for token // Exchange code for token
token, err := o.oauthConfig.Exchange(r.Context(), code) token, err := o.oauthConfig.Exchange(r.Context(), code)
if err != nil { if err != nil {
log.WithError(err).Error("Token exchange failed")
http.Error(w, "Authentication failed", http.StatusUnauthorized) http.Error(w, "Authentication failed", http.StatusUnauthorized)
return return
} }

View File

@@ -2,6 +2,7 @@ package roundtrip
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@@ -19,14 +20,19 @@ const deviceNamePrefix = "ingress-"
// backed by underlying NetBird connections. // backed by underlying NetBird connections.
type NetBird struct { type NetBird struct {
mgmtAddr string mgmtAddr string
logger *log.Logger
clientsMux sync.RWMutex clientsMux sync.RWMutex
clients map[string]*embed.Client clients map[string]*embed.Client
} }
func NewNetBird(mgmtAddr string) *NetBird { func NewNetBird(mgmtAddr string, logger *log.Logger) *NetBird {
if logger == nil {
logger = log.StandardLogger()
}
return &NetBird{ return &NetBird{
mgmtAddr: mgmtAddr, mgmtAddr: mgmtAddr,
logger: logger,
clients: make(map[string]*embed.Client), clients: make(map[string]*embed.Client),
} }
} }
@@ -37,13 +43,29 @@ func (n *NetBird) AddPeer(ctx context.Context, domain, key string) error {
ManagementURL: n.mgmtAddr, ManagementURL: n.mgmtAddr,
SetupKey: key, SetupKey: key,
LogOutput: io.Discard, LogOutput: io.Discard,
BlockInbound: true,
}) })
if err != nil { if err != nil {
return fmt.Errorf("create netbird client: %w", err) return fmt.Errorf("create netbird client: %w", err)
} }
if err := client.Start(ctx); err != nil {
return fmt.Errorf("start netbird client: %w", err) // Attempt to start the client in the background, if this fails
} // then it is not ideal, but it isn't the end of the world because
// we will try to start the client again before we use it.
go func() {
startCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
err = client.Start(startCtx)
switch {
case errors.Is(err, context.DeadlineExceeded):
n.logger.Debug("netbird client timed out")
// This is not ideal, but we will try again later.
return
case err != nil:
n.logger.WithField("domain", domain).WithError(err).Error("Unable to start netbird client, will try again later.")
}
}()
n.clientsMux.Lock() n.clientsMux.Lock()
defer n.clientsMux.Unlock() defer n.clientsMux.Unlock()
n.clients[domain] = client n.clients[domain] = client
@@ -78,7 +100,20 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("no peer connection found for host: %s", req.Host) return nil, fmt.Errorf("no peer connection found for host: %s", req.Host)
} }
log.WithFields(log.Fields{ // Attempt to start the client, if the client is already running then
// it will return an error that we ignore, if this hits a timeout then
// this request is unprocessable.
startCtx, cancel := context.WithTimeout(req.Context(), 3*time.Second)
defer cancel()
err := client.Start(startCtx)
switch {
case errors.Is(err, embed.ErrClientAlreadyStarted):
break
case err != nil:
return nil, fmt.Errorf("start netbird client: %w", err)
}
n.logger.WithFields(log.Fields{
"host": req.Host, "host": req.Host,
"url": req.URL.String(), "url": req.URL.String(),
"requestURI": req.RequestURI, "requestURI": req.RequestURI,

View File

@@ -50,6 +50,7 @@ type Server struct {
startTime time.Time startTime time.Time
ID string ID string
Logger *log.Logger
Version string Version string
ProxyURL string ProxyURL string
ManagementAddress string ManagementAddress string
@@ -71,6 +72,11 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.Version = "dev" s.Version = "dev"
} }
// If no logger is specified fallback to the standard logger.
if s.Logger == nil {
s.Logger = log.StandardLogger()
}
// The very first thing to do should be to connect to the Management server. // The very first thing to do should be to connect to the Management server.
// Without this connection, the Proxy cannot do anything. // Without this connection, the Proxy cannot do anything.
mgmtURL, err := url.Parse(s.ManagementAddress) mgmtURL, err := url.Parse(s.ManagementAddress)
@@ -91,7 +97,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
RootCAs: certPool, RootCAs: certPool,
}) })
} }
log.WithFields(log.Fields{ s.Logger.WithFields(log.Fields{
"gRPC_address": mgmtURL.Host, "gRPC_address": mgmtURL.Host,
"TLS_enabled": mgmtURL.Scheme == "https", "TLS_enabled": mgmtURL.Scheme == "https",
}).Debug("starting management gRPC client") }).Debug("starting management gRPC client")
@@ -111,12 +117,12 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
// Initialize the netbird client, this is required to build peer connections // Initialize the netbird client, this is required to build peer connections
// to proxy over. // to proxy over.
s.netbird = roundtrip.NewNetBird(s.ManagementAddress) s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.Logger)
// When generating ACME certificates, start a challenge server. // When generating ACME certificates, start a challenge server.
tlsConfig := &tls.Config{} tlsConfig := &tls.Config{}
if s.GenerateACMECertificates { if s.GenerateACMECertificates {
log.WithField("acme_server", s.ACMEDirectory).Debug("ACME certificates enabled, configuring certificate manager") s.Logger.WithField("acme_server", s.ACMEDirectory).Debug("ACME certificates enabled, configuring certificate manager")
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory) s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory)
s.http = &http.Server{ s.http = &http.Server{
Addr: s.ACMEChallengeAddress, Addr: s.ACMEChallengeAddress,
@@ -126,7 +132,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
if err := s.http.ListenAndServe(); err != nil { if err := s.http.ListenAndServe(); err != nil {
// Rather than retry, log the issue periodically so that hopefully someone notices and fixes the issue. // Rather than retry, log the issue periodically so that hopefully someone notices and fixes the issue.
for range time.Tick(10 * time.Second) { for range time.Tick(10 * time.Second) {
log.WithError(err).Error("ACME HTTP-01 challenge server error") s.Logger.WithError(err).Error("ACME HTTP-01 challenge server error")
} }
} }
}() }()
@@ -142,11 +148,11 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
// when using CNAME URLs to access the proxy. // when using CNAME URLs to access the proxy.
tlsConfig.ServerName = s.ProxyURL tlsConfig.ServerName = s.ProxyURL
log.WithFields(log.Fields{ s.Logger.WithFields(log.Fields{
"ServerName": s.ProxyURL, "ServerName": s.ProxyURL,
}).Debug("started ACME challenge server") }).Debug("started ACME challenge server")
} else { } else {
log.Debug("ACME certificates disabled, using static certificates") s.Logger.Debug("ACME certificates disabled, using static certificates")
// Otherwise pull some certificates from expected locations. // Otherwise pull some certificates from expected locations.
cert, err := tls.LoadX509KeyPair( cert, err := tls.LoadX509KeyPair(
filepath.Join(s.CertificateDirectory, "tls.crt"), filepath.Join(s.CertificateDirectory, "tls.crt"),
@@ -165,7 +171,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.auth = auth.NewMiddleware() s.auth = auth.NewMiddleware()
// Configure Access logs to management server. // Configure Access logs to management server.
accessLog := accesslog.NewLogger(mgmtClient) accessLog := accesslog.NewLogger(mgmtClient, s.Logger)
// Finally, start the reverse proxy. // Finally, start the reverse proxy.
s.https = &http.Server{ s.https = &http.Server{
@@ -179,23 +185,23 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.ProxyServiceClient) { func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.ProxyServiceClient) {
b := backoff.New(0, 0) b := backoff.New(0, 0)
for { for {
log.Debug("Getting mapping updates from management server") s.Logger.Debug("Getting mapping updates from management server")
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: s.ID, ProxyId: s.ID,
Version: s.Version, Version: s.Version,
StartedAt: timestamppb.New(s.startTime), StartedAt: timestamppb.New(s.startTime),
}) })
if err != nil { if err != nil {
log.WithError(err).Warn("Could not get mapping updates, will retry") s.Logger.WithError(err).Warn("Could not get mapping updates, will retry")
backoffDuration := b.Duration() backoffDuration := b.Duration()
log.WithFields(log.Fields{ s.Logger.WithFields(log.Fields{
"backoff": backoffDuration, "backoff": backoffDuration,
"error": err, "error": err,
}).Error("Unable to create mapping client to management server, retrying connection after backoff") }).Error("Unable to create mapping client to management server, retrying connection after backoff")
time.Sleep(backoffDuration) time.Sleep(backoffDuration)
continue continue
} }
log.Debug("Got mapping updates client from management server") s.Logger.Debug("Got mapping updates client from management server")
err = s.handleMappingStream(ctx, mappingClient) err = s.handleMappingStream(ctx, mappingClient)
backoffDuration := b.Duration() backoffDuration := b.Duration()
switch { switch {
@@ -203,17 +209,17 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
errors.Is(err, context.DeadlineExceeded): errors.Is(err, context.DeadlineExceeded):
// Context is telling us that it is time to quit so gracefully exit here. // Context is telling us that it is time to quit so gracefully exit here.
// No need to log the error as it is a parent context causing this return. // No need to log the error as it is a parent context causing this return.
log.Debugf("Got context error, will exit loop: %v", err) s.Logger.Debugf("Got context error, will exit loop: %v", err)
return return
case err != nil: case err != nil:
// Log the error and then retry the connection. // Log the error and then retry the connection.
log.WithFields(log.Fields{ s.Logger.WithFields(log.Fields{
"backoff": backoffDuration, "backoff": backoffDuration,
"error": err, "error": err,
}).Error("Error processing mapping stream from management server, retrying connection after backoff") }).Error("Error processing mapping stream from management server, retrying connection after backoff")
default: default:
// TODO: should this really be at error level? Maybe, if you start getting lots of these this could be an indication of connectivity issues. // TODO: should this really be at error level? Maybe, if you start getting lots of these this could be an indication of connectivity issues.
log.WithField("backoff", backoffDuration).Error("Management mapping connection terminated by the server, retrying connection after backoff") s.Logger.WithField("backoff", backoffDuration).Error("Management mapping connection terminated by the server, retrying connection after backoff")
} }
time.Sleep(backoffDuration) time.Sleep(backoffDuration)
} }
@@ -236,11 +242,11 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
// Something has gone horribly wrong, return and hope the parent retries the connection. // Something has gone horribly wrong, return and hope the parent retries the connection.
return fmt.Errorf("receive msg: %w", err) return fmt.Errorf("receive msg: %w", err)
} }
log.Debug("Received mapping update, starting processing") s.Logger.Debug("Received mapping update, starting processing")
// Process msg updates sequentially to avoid conflict, so block // Process msg updates sequentially to avoid conflict, so block
// additional receiving until this processing is completed. // additional receiving until this processing is completed.
for _, mapping := range msg.GetMapping() { for _, mapping := range msg.GetMapping() {
log.WithFields(log.Fields{ s.Logger.WithFields(log.Fields{
"type": mapping.GetType(), "type": mapping.GetType(),
"domain": mapping.GetDomain(), "domain": mapping.GetDomain(),
"path": mapping.GetPath(), "path": mapping.GetPath(),
@@ -250,7 +256,7 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED: case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
if err := s.addMapping(ctx, mapping); err != nil { if err := s.addMapping(ctx, mapping); err != nil {
// TODO: Retry this? Or maybe notify the management server that this mapping has failed? // TODO: Retry this? Or maybe notify the management server that this mapping has failed?
log.WithFields(log.Fields{ s.Logger.WithFields(log.Fields{
"service_id": mapping.GetId(), "service_id": mapping.GetId(),
"domain": mapping.GetDomain(), "domain": mapping.GetDomain(),
"error": err, "error": err,
@@ -262,6 +268,7 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
s.removeMapping(ctx, mapping) s.removeMapping(ctx, mapping)
} }
} }
s.Logger.Debug("Processing mapping update completed")
} }
} }
} }
@@ -305,7 +312,7 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
OIDCScopes: oidc.GetOidcScopes(), OIDCScopes: oidc.GetOidcScopes(),
}) })
if err != nil { if err != nil {
log.WithError(err).Error("Failed to create OIDC scheme") s.Logger.WithError(err).Error("Failed to create OIDC scheme")
} else { } else {
schemes = append(schemes, scheme) schemes = append(schemes, scheme)
} }
@@ -319,7 +326,7 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) { func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) {
if err := s.netbird.RemovePeer(ctx, mapping.GetDomain()); err != nil { if err := s.netbird.RemovePeer(ctx, mapping.GetDomain()); err != nil {
log.WithFields(log.Fields{ s.Logger.WithFields(log.Fields{
"domain": mapping.GetDomain(), "domain": mapping.GetDomain(),
"error": err, "error": err,
}).Error("Error removing NetBird peer connection for domain, continuing additional domain cleanup but peer connection may still exist") }).Error("Error removing NetBird peer connection for domain, continuing additional domain cleanup but peer connection may still exist")
@@ -337,7 +344,7 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
targetURL, err := url.Parse(pathMapping.GetTarget()) targetURL, err := url.Parse(pathMapping.GetTarget())
if err != nil { if err != nil {
// TODO: Should we warn management about this so it can be bubbled up to a user to reconfigure? // TODO: Should we warn management about this so it can be bubbled up to a user to reconfigure?
log.WithFields(log.Fields{ s.Logger.WithFields(log.Fields{
"service_id": mapping.GetId(), "service_id": mapping.GetId(),
"account_id": mapping.GetAccountId(), "account_id": mapping.GetAccountId(),
"domain": mapping.GetDomain(), "domain": mapping.GetDomain(),

View File

@@ -30,9 +30,14 @@ var (
// InitLog parses and sets log-level input // InitLog parses and sets log-level input
func InitLog(logLevel string, logs ...string) error { func InitLog(logLevel string, logs ...string) error {
return InitLogger(log.StandardLogger(), logLevel, logs...)
}
// InitLogger parses and sets log-level input for a logrus logger
func InitLogger(logger *log.Logger, logLevel string, logs ...string) error {
level, err := log.ParseLevel(logLevel) level, err := log.ParseLevel(logLevel)
if err != nil { if err != nil {
log.Errorf("Failed parsing log-level %s: %s", logLevel, err) logger.Errorf("Failed parsing log-level %s: %s", logLevel, err)
return err return err
} }
var writers []io.Writer var writers []io.Writer
@@ -41,34 +46,34 @@ func InitLog(logLevel string, logs ...string) error {
for _, logPath := range logs { for _, logPath := range logs {
switch logPath { switch logPath {
case LogSyslog: case LogSyslog:
AddSyslogHook() AddSyslogHookToLogger(logger)
logFmt = "syslog" logFmt = "syslog"
case LogConsole: case LogConsole:
writers = append(writers, os.Stderr) writers = append(writers, os.Stderr)
case "": case "":
log.Warnf("empty log path received: %#v", logPath) logger.Warnf("empty log path received: %#v", logPath)
default: default:
writers = append(writers, newRotatedOutput(logPath)) writers = append(writers, newRotatedOutput(logPath))
} }
} }
if len(writers) > 1 { if len(writers) > 1 {
log.SetOutput(io.MultiWriter(writers...)) logger.SetOutput(io.MultiWriter(writers...))
} else if len(writers) == 1 { } else if len(writers) == 1 {
log.SetOutput(writers[0]) logger.SetOutput(writers[0])
} }
switch logFmt { switch logFmt {
case "json": case "json":
formatter.SetJSONFormatter(log.StandardLogger()) formatter.SetJSONFormatter(logger)
case "syslog": case "syslog":
formatter.SetSyslogFormatter(log.StandardLogger()) formatter.SetSyslogFormatter(logger)
default: default:
formatter.SetTextFormatter(log.StandardLogger()) formatter.SetTextFormatter(logger)
} }
log.SetLevel(level) logger.SetLevel(level)
setGRPCLibLogger() setGRPCLibLogger(logger)
return nil return nil
} }
@@ -96,8 +101,8 @@ func newRotatedOutput(logPath string) io.Writer {
return lumberjackLogger return lumberjackLogger
} }
func setGRPCLibLogger() { func setGRPCLibLogger(logger *log.Logger) {
logOut := log.StandardLogger().Writer() logOut := logger.Writer()
if os.Getenv("GRPC_GO_LOG_SEVERITY_LEVEL") != "info" { if os.Getenv("GRPC_GO_LOG_SEVERITY_LEVEL") != "info" {
grpclog.SetLoggerV2(grpclog.NewLoggerV2(io.Discard, logOut, logOut)) grpclog.SetLoggerV2(grpclog.NewLoggerV2(io.Discard, logOut, logOut))
return return

View File

@@ -10,10 +10,14 @@ import (
) )
func AddSyslogHook() { func AddSyslogHook() {
AddSyslogHookToLogger(log.StandardLogger())
}
func AddSyslogHookToLogger(logger *log.Logger) {
hook, err := lSyslog.NewSyslogHook("", "", syslog.LOG_INFO, "") hook, err := lSyslog.NewSyslogHook("", "", syslog.LOG_INFO, "")
if err != nil { if err != nil {
log.Errorf("Failed creating syslog hook: %s", err) logger.Errorf("Failed creating syslog hook: %s", err)
} }
log.AddHook(hook) logger.AddHook(hook)
} }

View File

@@ -4,3 +4,8 @@ func AddSyslogHook() {
// The syslog package is not available for Windows. This adapter is needed // The syslog package is not available for Windows. This adapter is needed
// to handle windows build. // to handle windows build.
} }
func AddSyslogHookToLogger(logger *log.Logger) {
// The syslog package is not available for Windows. This adapter is needed
// to handle windows build.
}