diff --git a/proxy/internal/accesslog/middleware.go b/proxy/internal/accesslog/middleware.go index 27b072483..706de7d3c 100644 --- a/proxy/internal/accesslog/middleware.go +++ b/proxy/internal/accesslog/middleware.go @@ -6,7 +6,6 @@ import ( "time" "github.com/rs/xid" - log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/proxy/internal/auth" "github.com/netbirdio/netbird/proxy/internal/proxy" @@ -14,7 +13,7 @@ import ( func (l *Logger) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Debugf("access log middleware invoked for %s %s", r.Method, r.URL.Path) + l.logger.Debugf("access log middleware invoked for %s %s", r.Method, r.URL.Path) // Use a response writer wrapper so we can access the status code later. sw := &statusWriter{ w: w, diff --git a/proxy/internal/auth/middleware.go b/proxy/internal/auth/middleware.go index 5bfeb4156..917672c2d 100644 --- a/proxy/internal/auth/middleware.go +++ b/proxy/internal/auth/middleware.go @@ -42,11 +42,16 @@ type DomainConfig struct { type Middleware struct { domainsMux sync.RWMutex domains map[string]DomainConfig + logger *log.Logger } -func NewMiddleware() *Middleware { +func NewMiddleware(logger *log.Logger) *Middleware { + if logger == nil { + logger = log.StandardLogger() + } return &Middleware{ domains: make(map[string]DomainConfig), + logger: logger, } } @@ -69,7 +74,7 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler { config, exists := mw.domains[host] mw.domainsMux.RUnlock() - log.Debugf("checking authentication for host: %s, exists: %t", host, exists) + mw.logger.Debugf("checking authentication for host: %s, exists: %t", host, exists) // Domains that are not configured here or have no authentication schemes applied should simply pass through. if !exists || len(config.Schemes) == 0 { diff --git a/proxy/internal/proxy/reverseproxy.go b/proxy/internal/proxy/reverseproxy.go index da096102b..d64073941 100644 --- a/proxy/internal/proxy/reverseproxy.go +++ b/proxy/internal/proxy/reverseproxy.go @@ -8,6 +8,8 @@ import ( "strings" "sync" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/proxy/internal/roundtrip" "github.com/netbirdio/netbird/proxy/web" ) @@ -16,6 +18,7 @@ type ReverseProxy struct { transport http.RoundTripper mappingsMux sync.RWMutex mappings map[string]Mapping + logger *log.Logger } // NewReverseProxy configures a new NetBird ReverseProxy. @@ -24,10 +27,14 @@ type ReverseProxy struct { // between requested URLs and targets. // The internal mappings can be modified using the AddMapping // and RemoveMapping functions. -func NewReverseProxy(transport http.RoundTripper) *ReverseProxy { +func NewReverseProxy(transport http.RoundTripper, logger *log.Logger) *ReverseProxy { + if logger == nil { + logger = log.StandardLogger() + } return &ReverseProxy{ transport: transport, mappings: make(map[string]Mapping), + logger: logger, } } diff --git a/proxy/internal/proxy/servicemapping.go b/proxy/internal/proxy/servicemapping.go index 9bcf56478..e345bb622 100644 --- a/proxy/internal/proxy/servicemapping.go +++ b/proxy/internal/proxy/servicemapping.go @@ -7,8 +7,6 @@ import ( "sort" "strings" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/proxy/internal/types" ) @@ -37,7 +35,7 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string host = h } - log.Debugf("looking for mapping for host: %s, path: %s", host, req.URL.Path) + p.logger.Debugf("looking for mapping for host: %s, path: %s", host, req.URL.Path) m, exists := p.mappings[host] if !exists { return nil, "", "", false diff --git a/proxy/server.go b/proxy/server.go index 885e44b3d..2782911d9 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -210,10 +210,10 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { } // Configure the reverse proxy using NetBird's HTTP Client Transport for proxying. - s.proxy = proxy.NewReverseProxy(s.netbird) + s.proxy = proxy.NewReverseProxy(s.netbird, s.Logger) // Configure the authentication middleware. - s.auth = auth.NewMiddleware() + s.auth = auth.NewMiddleware(s.Logger) // Configure Access logs to management server. accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger)