ignore ports when performing proxy mapping lookups

This commit is contained in:
Alisdair MacLeod
2026-02-02 14:39:13 +00:00
parent fa6ff005f2
commit a73ee47557
3 changed files with 22 additions and 6 deletions

View File

@@ -6,6 +6,7 @@ import (
_ "embed" _ "embed"
"encoding/base64" "encoding/base64"
"html/template" "html/template"
"net"
"net/http" "net/http"
"sync" "sync"
"time" "time"
@@ -92,8 +93,12 @@ func NewMiddleware() *Middleware {
func (mw *Middleware) Protect(next http.Handler) http.Handler { func (mw *Middleware) Protect(next http.Handler) http.Handler {
tmpl := template.Must(template.New("auth").Parse(authTemplate)) tmpl := template.Must(template.New("auth").Parse(authTemplate))
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
host = r.Host
}
mw.domainsMux.RLock() mw.domainsMux.RLock()
schemes, exists := mw.domains[r.Host] schemes, exists := mw.domains[host]
mw.domainsMux.RUnlock() mw.domainsMux.RUnlock()
// Domains that are not configured here or have no authentication schemes applied should simply pass through. // Domains that are not configured here or have no authentication schemes applied should simply pass through.

View File

@@ -1,6 +1,7 @@
package proxy package proxy
import ( import (
"net"
"net/http" "net/http"
"net/url" "net/url"
"sort" "sort"
@@ -25,7 +26,12 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string
return nil, "", "", false return nil, "", "", false
} }
defer p.mappingsMux.RUnlock() defer p.mappingsMux.RUnlock()
m, exists := p.mappings[req.Host]
host, _, err := net.SplitHostPort(req.Host)
if err != nil {
host = req.Host
}
m, exists := p.mappings[host]
if !exists { if !exists {
return nil, "", "", false return nil, "", "", false
} }

View File

@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"sync" "sync"
"time" "time"
@@ -90,14 +91,18 @@ func (n *NetBird) RemovePeer(ctx context.Context, domain string) error {
} }
func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) { func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
host, _, err := net.SplitHostPort(req.Host)
if err != nil {
host = req.Host
}
n.clientsMux.RLock() n.clientsMux.RLock()
client, exists := n.clients[req.Host] client, exists := n.clients[host]
// Immediately unlock after retrieval here rather than defer to avoid // Immediately unlock after retrieval here rather than defer to avoid
// the call to client.Do blocking other clients being used whilst one // the call to client.Do blocking other clients being used whilst one
// is in use. // is in use.
n.clientsMux.RUnlock() n.clientsMux.RUnlock()
if !exists { if !exists {
return nil, fmt.Errorf("no peer connection found for host: %s", req.Host) return nil, fmt.Errorf("no peer connection found for host: %s", host)
} }
// Attempt to start the client, if the client is already running then // Attempt to start the client, if the client is already running then
@@ -105,7 +110,7 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
// this request is unprocessable. // this request is unprocessable.
startCtx, cancel := context.WithTimeout(req.Context(), 3*time.Second) startCtx, cancel := context.WithTimeout(req.Context(), 3*time.Second)
defer cancel() defer cancel()
err := client.Start(startCtx) err = client.Start(startCtx)
switch { switch {
case errors.Is(err, embed.ErrClientAlreadyStarted): case errors.Is(err, embed.ErrClientAlreadyStarted):
break break
@@ -114,7 +119,7 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
} }
n.logger.WithFields(log.Fields{ n.logger.WithFields(log.Fields{
"host": req.Host, "host": host,
"url": req.URL.String(), "url": req.URL.String(),
"requestURI": req.RequestURI, "requestURI": req.RequestURI,
"method": req.Method, "method": req.Method,