Use a defined logger

This should avoid issues with the embedded
client also attempting to use the same global logger.
This commit is contained in:
Alisdair MacLeod
2026-01-30 16:31:32 +00:00
parent 5345d716ee
commit 3a6f364b03
9 changed files with 115 additions and 47 deletions

View File

@@ -139,7 +139,10 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
group, err := s.keyStore.GetGroupByName(ctx, rp.Name, rp.AccountID)
if err != nil {
// TODO: log this?
log.WithFields(log.Fields{
"proxy": rp.Name,
"account": rp.AccountID,
}).WithError(err).Error("Failed to get group by name")
continue
}
@@ -156,7 +159,11 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
false,
)
if err != nil {
// TODO: how to handle this?
log.WithFields(log.Fields{
"proxy": rp.Name,
"account": rp.AccountID,
"group": group.ID,
}).WithError(err).Error("Failed to create setup key")
continue
}
@@ -168,7 +175,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
),
},
}); err != nil {
// TODO: log the error, maybe retry?
log.WithError(err).Error("Failed to send proxy mapping")
continue
}
}

View File

@@ -74,12 +74,14 @@ func main() {
if debug {
level = "debug"
}
logger := log.New()
_ = util.InitLog(level, util.LogConsole)
_ = util.InitLogger(logger, level, util.LogConsole)
log.Infof("configured log level: %s", level)
srv := proxy.Server{
Logger: logger,
Version: Version,
ManagementAddress: mgmtAddr,
ProxyURL: url,

View File

@@ -16,11 +16,16 @@ type gRPCClient interface {
type Logger struct {
client gRPCClient
logger *log.Logger
}
func NewLogger(client gRPCClient) *Logger {
func NewLogger(client gRPCClient, logger *log.Logger) *Logger {
if logger == nil {
logger = log.StandardLogger()
}
return &Logger{
client: client,
logger: logger,
}
}
@@ -68,7 +73,7 @@ func (l *Logger) log(ctx context.Context, entry logEntry) {
},
}); err != nil {
// If it fails to send on the gRPC connection, then at least log it to the error log.
log.WithFields(log.Fields{
l.logger.WithFields(log.Fields{
"service_id": entry.ServiceId,
"host": entry.Host,
"path": entry.Path,

View File

@@ -12,7 +12,6 @@ import (
"time"
"github.com/coreos/go-oidc/v3/oidc"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
)
@@ -165,7 +164,6 @@ func (o *OIDC) handleCallback(w http.ResponseWriter, r *http.Request) {
// Exchange code for token
token, err := o.oauthConfig.Exchange(r.Context(), code)
if err != nil {
log.WithError(err).Error("Token exchange failed")
http.Error(w, "Authentication failed", http.StatusUnauthorized)
return
}

View File

@@ -2,6 +2,7 @@ package roundtrip
import (
"context"
"errors"
"fmt"
"io"
"net/http"
@@ -19,14 +20,19 @@ const deviceNamePrefix = "ingress-"
// backed by underlying NetBird connections.
type NetBird struct {
mgmtAddr string
logger *log.Logger
clientsMux sync.RWMutex
clients map[string]*embed.Client
}
func NewNetBird(mgmtAddr string) *NetBird {
func NewNetBird(mgmtAddr string, logger *log.Logger) *NetBird {
if logger == nil {
logger = log.StandardLogger()
}
return &NetBird{
mgmtAddr: mgmtAddr,
logger: logger,
clients: make(map[string]*embed.Client),
}
}
@@ -37,13 +43,29 @@ func (n *NetBird) AddPeer(ctx context.Context, domain, key string) error {
ManagementURL: n.mgmtAddr,
SetupKey: key,
LogOutput: io.Discard,
BlockInbound: true,
})
if err != nil {
return fmt.Errorf("create netbird client: %w", err)
}
if err := client.Start(ctx); err != nil {
return fmt.Errorf("start netbird client: %w", err)
}
// Attempt to start the client in the background, if this fails
// then it is not ideal, but it isn't the end of the world because
// we will try to start the client again before we use it.
go func() {
startCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
err = client.Start(startCtx)
switch {
case errors.Is(err, context.DeadlineExceeded):
n.logger.Debug("netbird client timed out")
// This is not ideal, but we will try again later.
return
case err != nil:
n.logger.WithField("domain", domain).WithError(err).Error("Unable to start netbird client, will try again later.")
}
}()
n.clientsMux.Lock()
defer n.clientsMux.Unlock()
n.clients[domain] = client
@@ -78,7 +100,20 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("no peer connection found for host: %s", req.Host)
}
log.WithFields(log.Fields{
// Attempt to start the client, if the client is already running then
// it will return an error that we ignore, if this hits a timeout then
// this request is unprocessable.
startCtx, cancel := context.WithTimeout(req.Context(), 3*time.Second)
defer cancel()
err := client.Start(startCtx)
switch {
case errors.Is(err, embed.ErrClientAlreadyStarted):
break
case err != nil:
return nil, fmt.Errorf("start netbird client: %w", err)
}
n.logger.WithFields(log.Fields{
"host": req.Host,
"url": req.URL.String(),
"requestURI": req.RequestURI,

View File

@@ -50,6 +50,7 @@ type Server struct {
startTime time.Time
ID string
Logger *log.Logger
Version string
ProxyURL string
ManagementAddress string
@@ -71,6 +72,11 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.Version = "dev"
}
// If no logger is specified fallback to the standard logger.
if s.Logger == nil {
s.Logger = log.StandardLogger()
}
// The very first thing to do should be to connect to the Management server.
// Without this connection, the Proxy cannot do anything.
mgmtURL, err := url.Parse(s.ManagementAddress)
@@ -91,7 +97,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
RootCAs: certPool,
})
}
log.WithFields(log.Fields{
s.Logger.WithFields(log.Fields{
"gRPC_address": mgmtURL.Host,
"TLS_enabled": mgmtURL.Scheme == "https",
}).Debug("starting management gRPC client")
@@ -111,12 +117,12 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
// Initialize the netbird client, this is required to build peer connections
// to proxy over.
s.netbird = roundtrip.NewNetBird(s.ManagementAddress)
s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.Logger)
// When generating ACME certificates, start a challenge server.
tlsConfig := &tls.Config{}
if s.GenerateACMECertificates {
log.WithField("acme_server", s.ACMEDirectory).Debug("ACME certificates enabled, configuring certificate manager")
s.Logger.WithField("acme_server", s.ACMEDirectory).Debug("ACME certificates enabled, configuring certificate manager")
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory)
s.http = &http.Server{
Addr: s.ACMEChallengeAddress,
@@ -126,7 +132,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
if err := s.http.ListenAndServe(); err != nil {
// Rather than retry, log the issue periodically so that hopefully someone notices and fixes the issue.
for range time.Tick(10 * time.Second) {
log.WithError(err).Error("ACME HTTP-01 challenge server error")
s.Logger.WithError(err).Error("ACME HTTP-01 challenge server error")
}
}
}()
@@ -142,11 +148,11 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
// when using CNAME URLs to access the proxy.
tlsConfig.ServerName = s.ProxyURL
log.WithFields(log.Fields{
s.Logger.WithFields(log.Fields{
"ServerName": s.ProxyURL,
}).Debug("started ACME challenge server")
} else {
log.Debug("ACME certificates disabled, using static certificates")
s.Logger.Debug("ACME certificates disabled, using static certificates")
// Otherwise pull some certificates from expected locations.
cert, err := tls.LoadX509KeyPair(
filepath.Join(s.CertificateDirectory, "tls.crt"),
@@ -165,7 +171,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.auth = auth.NewMiddleware()
// Configure Access logs to management server.
accessLog := accesslog.NewLogger(mgmtClient)
accessLog := accesslog.NewLogger(mgmtClient, s.Logger)
// Finally, start the reverse proxy.
s.https = &http.Server{
@@ -179,23 +185,23 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.ProxyServiceClient) {
b := backoff.New(0, 0)
for {
log.Debug("Getting mapping updates from management server")
s.Logger.Debug("Getting mapping updates from management server")
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: s.ID,
Version: s.Version,
StartedAt: timestamppb.New(s.startTime),
})
if err != nil {
log.WithError(err).Warn("Could not get mapping updates, will retry")
s.Logger.WithError(err).Warn("Could not get mapping updates, will retry")
backoffDuration := b.Duration()
log.WithFields(log.Fields{
s.Logger.WithFields(log.Fields{
"backoff": backoffDuration,
"error": err,
}).Error("Unable to create mapping client to management server, retrying connection after backoff")
time.Sleep(backoffDuration)
continue
}
log.Debug("Got mapping updates client from management server")
s.Logger.Debug("Got mapping updates client from management server")
err = s.handleMappingStream(ctx, mappingClient)
backoffDuration := b.Duration()
switch {
@@ -203,17 +209,17 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
errors.Is(err, context.DeadlineExceeded):
// Context is telling us that it is time to quit so gracefully exit here.
// No need to log the error as it is a parent context causing this return.
log.Debugf("Got context error, will exit loop: %v", err)
s.Logger.Debugf("Got context error, will exit loop: %v", err)
return
case err != nil:
// Log the error and then retry the connection.
log.WithFields(log.Fields{
s.Logger.WithFields(log.Fields{
"backoff": backoffDuration,
"error": err,
}).Error("Error processing mapping stream from management server, retrying connection after backoff")
default:
// TODO: should this really be at error level? Maybe, if you start getting lots of these this could be an indication of connectivity issues.
log.WithField("backoff", backoffDuration).Error("Management mapping connection terminated by the server, retrying connection after backoff")
s.Logger.WithField("backoff", backoffDuration).Error("Management mapping connection terminated by the server, retrying connection after backoff")
}
time.Sleep(backoffDuration)
}
@@ -236,11 +242,11 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
// Something has gone horribly wrong, return and hope the parent retries the connection.
return fmt.Errorf("receive msg: %w", err)
}
log.Debug("Received mapping update, starting processing")
s.Logger.Debug("Received mapping update, starting processing")
// Process msg updates sequentially to avoid conflict, so block
// additional receiving until this processing is completed.
for _, mapping := range msg.GetMapping() {
log.WithFields(log.Fields{
s.Logger.WithFields(log.Fields{
"type": mapping.GetType(),
"domain": mapping.GetDomain(),
"path": mapping.GetPath(),
@@ -250,7 +256,7 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
if err := s.addMapping(ctx, mapping); err != nil {
// TODO: Retry this? Or maybe notify the management server that this mapping has failed?
log.WithFields(log.Fields{
s.Logger.WithFields(log.Fields{
"service_id": mapping.GetId(),
"domain": mapping.GetDomain(),
"error": err,
@@ -262,6 +268,7 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
s.removeMapping(ctx, mapping)
}
}
s.Logger.Debug("Processing mapping update completed")
}
}
}
@@ -305,7 +312,7 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
OIDCScopes: oidc.GetOidcScopes(),
})
if err != nil {
log.WithError(err).Error("Failed to create OIDC scheme")
s.Logger.WithError(err).Error("Failed to create OIDC scheme")
} else {
schemes = append(schemes, scheme)
}
@@ -319,7 +326,7 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) {
if err := s.netbird.RemovePeer(ctx, mapping.GetDomain()); err != nil {
log.WithFields(log.Fields{
s.Logger.WithFields(log.Fields{
"domain": mapping.GetDomain(),
"error": err,
}).Error("Error removing NetBird peer connection for domain, continuing additional domain cleanup but peer connection may still exist")
@@ -337,7 +344,7 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
targetURL, err := url.Parse(pathMapping.GetTarget())
if err != nil {
// TODO: Should we warn management about this so it can be bubbled up to a user to reconfigure?
log.WithFields(log.Fields{
s.Logger.WithFields(log.Fields{
"service_id": mapping.GetId(),
"account_id": mapping.GetAccountId(),
"domain": mapping.GetDomain(),

View File

@@ -30,9 +30,14 @@ var (
// InitLog parses and sets log-level input
func InitLog(logLevel string, logs ...string) error {
return InitLogger(log.StandardLogger(), logLevel, logs...)
}
// InitLogger parses and sets log-level input for a logrus logger
func InitLogger(logger *log.Logger, logLevel string, logs ...string) error {
level, err := log.ParseLevel(logLevel)
if err != nil {
log.Errorf("Failed parsing log-level %s: %s", logLevel, err)
logger.Errorf("Failed parsing log-level %s: %s", logLevel, err)
return err
}
var writers []io.Writer
@@ -41,34 +46,34 @@ func InitLog(logLevel string, logs ...string) error {
for _, logPath := range logs {
switch logPath {
case LogSyslog:
AddSyslogHook()
AddSyslogHookToLogger(logger)
logFmt = "syslog"
case LogConsole:
writers = append(writers, os.Stderr)
case "":
log.Warnf("empty log path received: %#v", logPath)
logger.Warnf("empty log path received: %#v", logPath)
default:
writers = append(writers, newRotatedOutput(logPath))
}
}
if len(writers) > 1 {
log.SetOutput(io.MultiWriter(writers...))
logger.SetOutput(io.MultiWriter(writers...))
} else if len(writers) == 1 {
log.SetOutput(writers[0])
logger.SetOutput(writers[0])
}
switch logFmt {
case "json":
formatter.SetJSONFormatter(log.StandardLogger())
formatter.SetJSONFormatter(logger)
case "syslog":
formatter.SetSyslogFormatter(log.StandardLogger())
formatter.SetSyslogFormatter(logger)
default:
formatter.SetTextFormatter(log.StandardLogger())
formatter.SetTextFormatter(logger)
}
log.SetLevel(level)
logger.SetLevel(level)
setGRPCLibLogger()
setGRPCLibLogger(logger)
return nil
}
@@ -96,8 +101,8 @@ func newRotatedOutput(logPath string) io.Writer {
return lumberjackLogger
}
func setGRPCLibLogger() {
logOut := log.StandardLogger().Writer()
func setGRPCLibLogger(logger *log.Logger) {
logOut := logger.Writer()
if os.Getenv("GRPC_GO_LOG_SEVERITY_LEVEL") != "info" {
grpclog.SetLoggerV2(grpclog.NewLoggerV2(io.Discard, logOut, logOut))
return

View File

@@ -10,10 +10,14 @@ import (
)
func AddSyslogHook() {
AddSyslogHookToLogger(log.StandardLogger())
}
func AddSyslogHookToLogger(logger *log.Logger) {
hook, err := lSyslog.NewSyslogHook("", "", syslog.LOG_INFO, "")
if err != nil {
log.Errorf("Failed creating syslog hook: %s", err)
logger.Errorf("Failed creating syslog hook: %s", err)
}
log.AddHook(hook)
logger.AddHook(hook)
}

View File

@@ -4,3 +4,8 @@ func AddSyslogHook() {
// The syslog package is not available for Windows. This adapter is needed
// to handle windows build.
}
func AddSyslogHookToLogger(logger *log.Logger) {
// The syslog package is not available for Windows. This adapter is needed
// to handle windows build.
}