diff --git a/proxy/internal/auth/middleware.go b/proxy/internal/auth/middleware.go index 3bd611e8f..69776a5fa 100644 --- a/proxy/internal/auth/middleware.go +++ b/proxy/internal/auth/middleware.go @@ -6,6 +6,7 @@ import ( _ "embed" "encoding/base64" "html/template" + "net" "net/http" "sync" "time" @@ -92,8 +93,12 @@ func NewMiddleware() *Middleware { func (mw *Middleware) Protect(next http.Handler) http.Handler { tmpl := template.Must(template.New("auth").Parse(authTemplate)) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host, _, err := net.SplitHostPort(r.Host) + if err != nil { + host = r.Host + } mw.domainsMux.RLock() - schemes, exists := mw.domains[r.Host] + schemes, exists := mw.domains[host] mw.domainsMux.RUnlock() // Domains that are not configured here or have no authentication schemes applied should simply pass through. diff --git a/proxy/internal/proxy/servicemapping.go b/proxy/internal/proxy/servicemapping.go index 0fc1dbb57..be9c0ed29 100644 --- a/proxy/internal/proxy/servicemapping.go +++ b/proxy/internal/proxy/servicemapping.go @@ -1,6 +1,7 @@ package proxy import ( + "net" "net/http" "net/url" "sort" @@ -25,7 +26,12 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string return nil, "", "", false } defer p.mappingsMux.RUnlock() - m, exists := p.mappings[req.Host] + + host, _, err := net.SplitHostPort(req.Host) + if err != nil { + host = req.Host + } + m, exists := p.mappings[host] if !exists { return nil, "", "", false } diff --git a/proxy/internal/roundtrip/netbird.go b/proxy/internal/roundtrip/netbird.go index 3ce5608c9..2933ece28 100644 --- a/proxy/internal/roundtrip/netbird.go +++ b/proxy/internal/roundtrip/netbird.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "net" "net/http" "sync" "time" @@ -90,14 +91,18 @@ func (n *NetBird) RemovePeer(ctx context.Context, domain string) error { } func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) { + host, _, err := net.SplitHostPort(req.Host) + if err != nil { + host = req.Host + } n.clientsMux.RLock() - client, exists := n.clients[req.Host] + client, exists := n.clients[host] // Immediately unlock after retrieval here rather than defer to avoid // the call to client.Do blocking other clients being used whilst one // is in use. n.clientsMux.RUnlock() if !exists { - return nil, fmt.Errorf("no peer connection found for host: %s", req.Host) + return nil, fmt.Errorf("no peer connection found for host: %s", host) } // Attempt to start the client, if the client is already running then @@ -105,7 +110,7 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) { // this request is unprocessable. startCtx, cancel := context.WithTimeout(req.Context(), 3*time.Second) defer cancel() - err := client.Start(startCtx) + err = client.Start(startCtx) switch { case errors.Is(err, embed.ErrClientAlreadyStarted): break @@ -114,7 +119,7 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) { } n.logger.WithFields(log.Fields{ - "host": req.Host, + "host": host, "url": req.URL.String(), "requestURI": req.RequestURI, "method": req.Method,