Merge branch 'fix/log-formatter' into fix/http-redirect

This commit is contained in:
Zoltán Papp
2026-02-13 13:22:39 +01:00
34 changed files with 5542 additions and 922 deletions

View File

@@ -39,10 +39,10 @@ var (
addr string
proxyDomain string
certDir string
acmeCerts bool
acmeAddr string
acmeDir string
acmeChallengeType string
acmeCerts bool
acmeAddr string
acmeDir string
acmeChallengeType string
debugEndpoint bool
debugEndpointAddr string
healthAddr string
@@ -123,7 +123,7 @@ func runServer(cmd *cobra.Command, args []string) error {
_ = util.InitLogger(logger, level, util.LogConsole)
log.Infof("configured log level: %s", level)
logger.Infof("configured log level: %s", level)
switch forwardedProto {
case "auto", "http", "https":
@@ -171,7 +171,7 @@ func runServer(cmd *cobra.Command, args []string) error {
defer stop()
if err := srv.ListenAndServe(ctx, addr); err != nil {
log.Error(err)
logger.Error(err)
return err
}
return nil

View File

@@ -84,7 +84,7 @@ func (l *flockLocker) Lock(ctx context.Context, domain string) (func(), error) {
// nil lockFile means locking is not supported (non-unix).
if lockFile == nil {
return func() {}, nil
return func() { /* no-op: locking unsupported on this platform */ }, nil
}
return func() {
@@ -98,5 +98,5 @@ type noopLocker struct{}
// Lock is a no-op that always succeeds immediately.
func (noopLocker) Lock(context.Context, string) (func(), error) {
return func() {}, nil
return func() { /* no-op: locker disabled */ }, nil
}

View File

@@ -90,10 +90,8 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
if err != nil {
host = r.Host
}
mw.domainsMux.RLock()
config, exists := mw.domains[host]
mw.domainsMux.RUnlock()
config, exists := mw.getDomainConfig(host)
mw.logger.Debugf("checking authentication for host: %s, exists: %t", host, exists)
// Domains that are not configured here or have no authentication schemes applied should simply pass through.
@@ -103,115 +101,160 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
}
// Set account and service IDs in captured data for access logging.
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetAccountId(types.AccountID(config.AccountID))
cd.SetServiceId(config.ServiceID)
}
setCapturedIDs(r, config)
// Check for error from OAuth callback (e.g., access denied)
if errCode := r.URL.Query().Get("error"); errCode != "" {
var requestID string
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetAuthMethod(auth.MethodOIDC.String())
requestID = cd.GetRequestID()
}
errDesc := r.URL.Query().Get("error_description")
if errDesc == "" {
errDesc = "An error occurred during authentication"
}
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID)
if mw.handleOAuthCallbackError(w, r) {
return
}
// Check for an existing session cookie (contains JWT)
if cookie, err := r.Cookie(auth.SessionCookieName); err == nil {
if userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey); err == nil {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetUserID(userID)
cd.SetAuthMethod(method)
}
next.ServeHTTP(w, r)
return
}
if mw.forwardWithSessionCookie(w, r, host, config, next) {
return
}
// Try to authenticate with each scheme.
methods := make(map[string]string)
var attemptedMethod string
for _, scheme := range config.Schemes {
token, promptData, err := scheme.Authenticate(r)
if err != nil {
mw.logger.WithField("scheme", scheme.Type().String()).Warnf("authentication infrastructure error: %v", err)
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
}
http.Error(w, "authentication service unavailable", http.StatusBadGateway)
return
mw.authenticateWithSchemes(w, r, host, config)
})
}
// getDomainConfig looks up the configuration registered for host while
// holding the read lock on the domains map. The boolean reports whether an
// entry was found.
func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) {
	mw.domainsMux.RLock()
	cfg, found := mw.domains[host]
	mw.domainsMux.RUnlock()
	return cfg, found
}
// setCapturedIDs copies the account and service IDs from config into the
// captured data attached to the request context. It does nothing when the
// context carries no captured data.
func setCapturedIDs(r *http.Request, config DomainConfig) {
	captured := proxy.CapturedDataFromContext(r.Context())
	if captured == nil {
		return
	}
	captured.SetAccountId(types.AccountID(config.AccountID))
	captured.SetServiceId(config.ServiceID)
}
// handleOAuthCallbackError looks for an "error" query parameter, as set by an
// OAuth callback (e.g. access denied). When one is present it renders the
// access denied page with the accompanying "error_description" (or a generic
// message if none was supplied) and returns true; otherwise it returns false
// without writing a response.
func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Request) bool {
	query := r.URL.Query()
	if query.Get("error") == "" {
		return false
	}

	description := query.Get("error_description")
	if description == "" {
		description = "An error occurred during authentication"
	}

	// Tag the captured request data so the failure is attributed to the
	// auth layer, and pick up the request ID for the error page.
	var requestID string
	if captured := proxy.CapturedDataFromContext(r.Context()); captured != nil {
		captured.SetOrigin(proxy.OriginAuth)
		captured.SetAuthMethod(auth.MethodOIDC.String())
		requestID = captured.GetRequestID()
	}

	web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", description, requestID)
	return true
}
// forwardWithSessionCookie serves the request through next when it carries a
// session cookie whose JWT validates for host against the domain's session
// public key. On success the user ID and auth method are recorded on the
// captured request data and true is returned. A missing cookie or a failed
// validation returns false and leaves the response untouched.
func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
	sessionCookie, cookieErr := r.Cookie(auth.SessionCookieName)
	if cookieErr != nil {
		return false
	}

	userID, method, jwtErr := auth.ValidateSessionJWT(sessionCookie.Value, host, config.SessionPublicKey)
	if jwtErr != nil {
		return false
	}

	captured := proxy.CapturedDataFromContext(r.Context())
	if captured != nil {
		captured.SetUserID(userID)
		captured.SetAuthMethod(method)
	}

	next.ServeHTTP(w, r)
	return true
}
// authenticateWithSchemes tries each configured auth scheme in order.
// On success it sets a session cookie and redirects; on failure it renders the login page.
func (mw *Middleware) authenticateWithSchemes(w http.ResponseWriter, r *http.Request, host string, config DomainConfig) {
methods := make(map[string]string)
var attemptedMethod string
for _, scheme := range config.Schemes {
token, promptData, err := scheme.Authenticate(r)
if err != nil {
mw.logger.WithField("scheme", scheme.Type().String()).Warnf("authentication infrastructure error: %v", err)
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
}
// Track if credentials were submitted but auth failed
if token == "" && wasCredentialSubmitted(r, scheme.Type()) {
attemptedMethod = scheme.Type().String()
}
if token != "" {
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, scheme.Type())
if err != nil {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetAuthMethod(scheme.Type().String())
}
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if !result.Valid {
var requestID string
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetAuthMethod(scheme.Type().String())
requestID = cd.GetRequestID()
}
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", "You are not authorized to access this service", requestID)
return
}
expiration := config.SessionExpiration
if expiration == 0 {
expiration = auth.DefaultSessionExpiry
}
http.SetCookie(w, &http.Cookie{
Name: auth.SessionCookieName,
Value: token,
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
MaxAge: int(expiration.Seconds()),
})
// Redirect instead of forwarding the auth POST to the backend.
// The browser will follow with a GET carrying the new session cookie.
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetAuthMethod(scheme.Type().String())
}
redirectURL := stripSessionTokenParam(r.URL)
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return
}
methods[scheme.Type().String()] = promptData
http.Error(w, "authentication service unavailable", http.StatusBadGateway)
return
}
// Track if credentials were submitted but auth failed
if token == "" && wasCredentialSubmitted(r, scheme.Type()) {
attemptedMethod = scheme.Type().String()
}
if token != "" {
mw.handleAuthenticatedToken(w, r, host, token, config, scheme)
return
}
methods[scheme.Type().String()] = promptData
}
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
if attemptedMethod != "" {
cd.SetAuthMethod(attemptedMethod)
}
}
web.ServeHTTP(w, r, map[string]any{"methods": methods}, http.StatusUnauthorized)
}
// handleAuthenticatedToken validates the token, handles denied access, and on
// success sets a session cookie and redirects to the original URL.
func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Request, host, token string, config DomainConfig, scheme Scheme) {
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, scheme.Type())
if err != nil {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
if attemptedMethod != "" {
cd.SetAuthMethod(attemptedMethod)
}
cd.SetAuthMethod(scheme.Type().String())
}
web.ServeHTTP(w, r, map[string]any{"methods": methods}, http.StatusUnauthorized)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if !result.Valid {
var requestID string
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetAuthMethod(scheme.Type().String())
requestID = cd.GetRequestID()
}
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", "You are not authorized to access this service", requestID)
return
}
expiration := config.SessionExpiration
if expiration == 0 {
expiration = auth.DefaultSessionExpiry
}
http.SetCookie(w, &http.Cookie{
Name: auth.SessionCookieName,
Value: token,
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
MaxAge: int(expiration.Seconds()),
})
// Redirect instead of forwarding the auth POST to the backend.
// The browser will follow with a GET carrying the new session cookie.
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetAuthMethod(scheme.Type().String())
}
redirectURL := stripSessionTokenParam(r.URL)
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
// wasCredentialSubmitted checks if credentials were submitted for the given auth method.

View File

@@ -83,6 +83,10 @@ func (c *Client) printHealth(data map[string]any) {
}
}
c.printHealthClients(data)
}
func (c *Client) printHealthClients(data map[string]any) {
clients, ok := data["clients"].(map[string]any)
if !ok || len(clients) == 0 {
return

View File

@@ -0,0 +1,71 @@
package debug
import (
"bytes"
"testing"
"github.com/stretchr/testify/assert"
)
// TestPrintHealth_WithCertsAndClients passes printHealth a payload that
// includes certificate counters, certificate domain lists and one client
// entry, then checks that the expected fragments appear in the output.
func TestPrintHealth_WithCertsAndClients(t *testing.T) {
	var out bytes.Buffer
	client := NewClient("localhost:8444", false, &out)

	accountHealth := map[string]any{
		"healthy":              true,
		"management_connected": true,
		"signal_connected":     true,
		"relays_connected":     float64(1),
		"relays_total":         float64(2),
		"peers_connected":      float64(3),
		"peers_total":          float64(5),
		"peers_p2p":            float64(2),
		"peers_relayed":        float64(1),
		"peers_degraded":       float64(0),
	}
	payload := map[string]any{
		"status":                "ok",
		"uptime":                "1h30m",
		"management_connected":  true,
		"all_clients_healthy":   true,
		"certs_total":           float64(3),
		"certs_ready":           float64(2),
		"certs_pending":         float64(1),
		"certs_failed":          float64(0),
		"certs_ready_domains":   []any{"a.example.com", "b.example.com"},
		"certs_pending_domains": []any{"c.example.com"},
		"clients":               map[string]any{"acc-1": accountHealth},
	}

	client.printHealth(payload)

	rendered := out.String()
	expected := []string{
		"Status: ok",
		"Uptime: 1h30m",
		"yes", // management_connected
		"2 ready, 1 pending, 0 failed (3 total)",
		"a.example.com",
		"c.example.com",
		"acc-1",
	}
	for _, fragment := range expected {
		assert.Contains(t, rendered, fragment)
	}
}
// TestPrintHealth_Minimal passes printHealth a payload with no certificate or
// client data and checks that the status lines are printed while the
// certificate and client sections are absent.
func TestPrintHealth_Minimal(t *testing.T) {
	var out bytes.Buffer
	client := NewClient("localhost:8444", false, &out)

	client.printHealth(map[string]any{
		"status":               "ok",
		"uptime":               "5m",
		"management_connected": false,
		"all_clients_healthy":  false,
	})

	rendered := out.String()
	for _, present := range []string{"Status: ok", "Uptime: 5m"} {
		assert.Contains(t, rendered, present)
	}
	for _, absent := range []string{"Certificates", "ACCOUNT ID"} {
		assert.NotContains(t, rendered, absent)
	}
}

View File

@@ -323,7 +323,7 @@ func NewServer(addr string, checker *Checker, logger *log.Logger, metricsHandler
if metricsHandler != nil {
mux := http.NewServeMux()
mux.Handle("/metrics", metricsHandler)
mux.Handle("/", checker.Handler())
mux.Handle("/", handler)
handler = mux
}

View File

@@ -404,3 +404,70 @@ func TestChecker_Handler_Full(t *testing.T) {
// Clients may be empty map when no clients exist.
assert.Empty(t, resp.Clients)
}
// TestChecker_SetShuttingDown checks that the readiness probe passes while
// management is connected and fails once SetShuttingDown has been called.
func TestChecker_SetShuttingDown(t *testing.T) {
	c := NewChecker(nil, &mockClientProvider{})
	c.SetManagementConnected(true)

	readyBefore := c.ReadinessProbe()
	assert.True(t, readyBefore, "should be ready before shutdown")

	c.SetShuttingDown()

	readyAfter := c.ReadinessProbe()
	assert.False(t, readyAfter, "should not be ready after shutdown")
}
// TestChecker_Handler_Readiness_ShuttingDown checks that the readiness
// endpoint answers 503 with a "fail" status after SetShuttingDown, even
// though management is marked as connected.
func TestChecker_Handler_Readiness_ShuttingDown(t *testing.T) {
	c := NewChecker(nil, &mockClientProvider{})
	c.SetManagementConnected(true)
	c.SetShuttingDown()

	h := c.Handler()
	request := httptest.NewRequest(http.MethodGet, "/healthz/ready", nil)
	recorder := httptest.NewRecorder()
	h.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)

	var body ProbeResponse
	decodeErr := json.NewDecoder(recorder.Body).Decode(&body)
	require.NoError(t, decodeErr)
	assert.Equal(t, "fail", body.Status)
}
// TestNewServer_WithMetricsHandler builds a server with a metrics handler and
// checks that both the liveness endpoint and /metrics are reachable through
// the server's handler.
func TestNewServer_WithMetricsHandler(t *testing.T) {
	c := NewChecker(nil, &mockClientProvider{})
	c.SetManagementConnected(true)

	metrics := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("metrics"))
	})

	srv := NewServer(":0", c, nil, metrics)
	require.NotNil(t, srv)

	// serve issues a GET for path against the server's handler and returns
	// the recorded response.
	serve := func(path string) *httptest.ResponseRecorder {
		request := httptest.NewRequest(http.MethodGet, path, nil)
		recorder := httptest.NewRecorder()
		srv.server.Handler.ServeHTTP(recorder, request)
		return recorder
	}

	// Verify health endpoint still works through the mux.
	live := serve("/healthz/live")
	assert.Equal(t, http.StatusOK, live.Code)

	// Verify metrics endpoint is mounted.
	scraped := serve("/metrics")
	assert.Equal(t, http.StatusOK, scraped.Code)
	assert.Equal(t, "metrics", scraped.Body.String())
}
// TestNewServer_WithoutMetricsHandler builds a server with no metrics handler
// and checks that the liveness endpoint still answers 200.
func TestNewServer_WithoutMetricsHandler(t *testing.T) {
	c := NewChecker(nil, &mockClientProvider{})
	c.SetManagementConnected(true)

	srv := NewServer(":0", c, nil, nil)
	require.NotNil(t, srv)

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodGet, "/healthz/live", nil)
	srv.server.Handler.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusOK, recorder.Code)
}

View File

@@ -320,7 +320,8 @@ func getRequestID(r *http.Request) string {
// status code, and component status based on the error type.
func classifyProxyError(err error) (title, message string, code int, status web.ErrorStatus) {
switch {
case errors.Is(err, context.DeadlineExceeded):
case errors.Is(err, context.DeadlineExceeded),
isNetTimeout(err):
return "Request Timeout",
"The request timed out while trying to reach the service. Please refresh the page and try again.",
http.StatusGatewayTimeout,
@@ -356,12 +357,6 @@ func classifyProxyError(err error) (title, message string, code int, status web.
"The connection to the peer could not be established. Please ensure the peer is running and connected to the NetBird network.",
http.StatusBadGateway,
web.ErrorStatus{Proxy: true, Destination: false}
case isNetTimeout(err):
return "Request Timeout",
"The request timed out while trying to reach the service. Please refresh the page and try again.",
http.StatusGatewayTimeout,
web.ErrorStatus{Proxy: true, Destination: false}
}
return "Connection Error",

View File

@@ -38,6 +38,11 @@ type domainInfo struct {
serviceID string
}
// domainNotification pairs a domain with its service ID so that status
// notifications can be sent after the clients lock has been released.
type domainNotification struct {
	domain    domain.Domain
	serviceID string
}
// clientEntry holds an embedded NetBird client and tracks which domains use it.
type clientEntry struct {
client *embed.Client
@@ -114,6 +119,30 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
return nil
}
entry, err := n.createClientEntry(ctx, accountID, d, authToken, serviceID)
if err != nil {
n.clientsMux.Unlock()
return err
}
n.clients[accountID] = entry
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).Info("created new client for account")
// Attempt to start the client in the background; if this fails we will
// retry on the first request via RoundTrip.
go n.runClientStartup(ctx, accountID, entry.client)
return nil
}
// createClientEntry generates a WireGuard keypair, authenticates with management,
// and creates an embedded NetBird client. Must be called with clientsMux held.
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) (*clientEntry, error) {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_id": serviceID,
@@ -121,8 +150,7 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
privateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
n.clientsMux.Unlock()
return fmt.Errorf("generate wireguard private key: %w", err)
return nil, fmt.Errorf("generate wireguard private key: %w", err)
}
publicKey := privateKey.PublicKey()
@@ -132,7 +160,6 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
"public_key": publicKey.String(),
}).Debug("authenticating new proxy peer with management")
// Authenticate with management using the one-time token and send public key
resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
ServiceId: serviceID,
AccountId: string(accountID),
@@ -141,16 +168,14 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
Cluster: n.proxyAddr,
})
if err != nil {
n.clientsMux.Unlock()
return fmt.Errorf("authenticate proxy peer with management: %w", err)
return nil, fmt.Errorf("authenticate proxy peer with management: %w", err)
}
if resp != nil && !resp.GetSuccess() {
n.clientsMux.Unlock()
errMsg := "unknown error"
if resp.ErrorMessage != nil {
errMsg = *resp.ErrorMessage
}
return fmt.Errorf("proxy peer authentication failed: %s", errMsg)
return nil, fmt.Errorf("proxy peer authentication failed: %s", errMsg)
}
n.logger.WithFields(log.Fields{
@@ -176,14 +201,13 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
WireguardPort: &n.wgPort,
})
if err != nil {
n.clientsMux.Unlock()
return fmt.Errorf("create netbird client: %w", err)
return nil, fmt.Errorf("create netbird client: %w", err)
}
// Create a transport using the client dialer. We do this instead of using
// the client's HTTPClient to avoid issues with request validation that do
// not work with reverse proxied requests.
entry = &clientEntry{
return &clientEntry{
client: client,
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
transport: &http.Transport{
@@ -196,75 +220,53 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
},
createdAt: time.Now(),
started: false,
}, nil
}
// runClientStartup starts the client and notifies registered domains on success.
func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) {
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := client.Start(startCtx); err != nil {
if errors.Is(err, context.DeadlineExceeded) {
n.logger.WithField("account_id", accountID).Warn("netbird client start timed out, will retry on first request")
} else {
n.logger.WithField("account_id", accountID).WithError(err).Error("failed to start netbird client")
}
return
}
// Mark client as started and collect domains to notify outside the lock.
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if exists {
entry.started = true
}
var domainsToNotify []domainNotification
if exists {
for dom, info := range entry.domains {
domainsToNotify = append(domainsToNotify, domainNotification{domain: dom, serviceID: info.serviceID})
}
}
n.clients[accountID] = entry
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).Info("created new client for account")
// Attempt to start the client in the background, if this fails
// then it is not ideal, but it isn't the end of the world because
// we will try to start the client again before we use it.
go func() {
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := client.Start(startCtx); err != nil {
if errors.Is(err, context.DeadlineExceeded) {
n.logger.WithFields(log.Fields{
"account_id": accountID,
}).Warn("netbird client start timed out, will retry on first request")
} else {
n.logger.WithFields(log.Fields{
"account_id": accountID,
}).WithError(err).Error("failed to start netbird client")
}
return
if n.statusNotifier == nil {
return
}
for _, dn := range domainsToNotify {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), dn.serviceID, string(dn.domain), true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": dn.domain,
}).WithError(err).Warn("failed to notify tunnel connection status")
} else {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": dn.domain,
}).Info("notified management about tunnel connection")
}
// Mark client as started and notify all registered domains
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if exists {
entry.started = true
}
// Copy domain info while holding lock
var domainsToNotify []struct {
domain domain.Domain
serviceID string
}
if exists {
for dom, info := range entry.domains {
domainsToNotify = append(domainsToNotify, struct {
domain domain.Domain
serviceID string
}{domain: dom, serviceID: info.serviceID})
}
}
n.clientsMux.Unlock()
// Notify all domains that they're connected
if n.statusNotifier != nil {
for _, domInfo := range domainsToNotify {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(domInfo.domain), true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": domInfo.domain,
}).WithError(err).Warn("failed to notify tunnel connection status")
} else {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": domInfo.domain,
}).Info("notified management about tunnel connection")
}
}
}
}()
return nil
}
}
// RemovePeer unregisters a domain from an account. The client is only stopped

View File

@@ -3,6 +3,7 @@ package roundtrip
import (
"context"
"net/http"
"sync"
"testing"
"github.com/stretchr/testify/assert"
@@ -20,6 +21,31 @@ func (m *mockMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxy
return &proto.CreateProxyPeerResponse{Success: true}, nil
}
// mockStatusNotifier is a test double that records every NotifyStatus call.
type mockStatusNotifier struct {
	mu       sync.Mutex   // guards statuses
	statuses []statusCall // recorded calls, in the order they were received
}
// statusCall captures the arguments of a single NotifyStatus invocation.
// NotifyStatus builds it with a positional literal, so the field order must
// stay in sync with that call.
type statusCall struct {
	accountID string
	serviceID string
	domain    string
	connected bool
}
// NotifyStatus records the call's arguments under the mutex and always
// reports success.
func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID, serviceID, domain string, connected bool) error {
	recorded := statusCall{
		accountID: accountID,
		serviceID: serviceID,
		domain:    domain,
		connected: connected,
	}

	m.mu.Lock()
	m.statuses = append(m.statuses, recorded)
	m.mu.Unlock()

	return nil
}
// calls returns a copy of the recorded NotifyStatus calls, taken under the
// mutex, so callers can inspect it without holding the lock.
func (m *mockStatusNotifier) calls() []statusCall {
	m.mu.Lock()
	defer m.mu.Unlock()

	snapshot := make([]statusCall, len(m.statuses))
	copy(snapshot, m.statuses)
	return snapshot
}
// mockNetBird creates a NetBird instance for testing without actually connecting.
// It uses an invalid management URL to prevent real connections.
func mockNetBird() *NetBird {
@@ -253,3 +279,50 @@ func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) {
assert.Error(t, err)
assert.Contains(t, err.Error(), "no peer connection found for account")
}
// TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus checks that adding
// a second domain to an account whose client is already marked as started
// produces exactly one "connected" notification, and that it is for the
// newly added domain.
func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
	notifier := &mockStatusNotifier{}
	nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{})
	accountID := types.AccountID("account-1")
	// Add first domain — creates a new client entry.
	err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
	require.NoError(t, err)
	// Manually mark client as started to simulate background startup completing.
	nb.clientsMux.Lock()
	nb.clients[accountID].started = true
	nb.clientsMux.Unlock()
	// Add second domain — should notify immediately since client is already started.
	err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
	require.NoError(t, err)
	// Only the second AddPeer is expected to have produced a notification.
	calls := notifier.calls()
	require.Len(t, calls, 1)
	assert.Equal(t, string(accountID), calls[0].accountID)
	assert.Equal(t, "svc-2", calls[0].serviceID)
	assert.Equal(t, "domain2.test", calls[0].domain)
	assert.True(t, calls[0].connected)
}
// TestNetBird_RemovePeer_NotifiesDisconnection registers two domains for one
// account, removes the first, and checks that the client is kept and that a
// single "disconnected" notification is recorded for the removed domain.
func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
	notifier := &mockStatusNotifier{}
	nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{})
	accountID := types.AccountID("account-1")
	err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
	require.NoError(t, err)
	err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
	require.NoError(t, err)
	// Remove one domain — client stays, but disconnection notification fires.
	err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
	require.NoError(t, err)
	assert.True(t, nb.HasClient(accountID))
	// The only recorded call should be the disconnection for domain1.test.
	calls := notifier.calls()
	require.Len(t, calls, 1)
	assert.Equal(t, "domain1.test", calls[0].domain)
	assert.False(t, calls[0].connected)
}

View File

@@ -157,41 +157,9 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
reg := prometheus.NewRegistry()
s.meter = metrics.New(reg)
// The very first thing to do should be to connect to the Management server.
// Without this connection, the Proxy cannot do anything.
mgmtURL, err := url.Parse(s.ManagementAddress)
mgmtConn, err := s.dialManagement()
if err != nil {
return fmt.Errorf("parse management address: %w", err)
}
creds := insecure.NewCredentials()
// Simple TLS check using management URL.
// Assume management TLS is enabled for gRPC as well if using HTTPS for the API.
if mgmtURL.Scheme == "https" {
certPool, err := x509.SystemCertPool()
if err != nil || certPool == nil {
// Fall back to embedded CAs if no OS-provided ones are available.
certPool = embeddedroots.Get()
}
creds = credentials.NewTLS(&tls.Config{
RootCAs: certPool,
})
}
s.Logger.WithFields(log.Fields{
"gRPC_address": mgmtURL.Host,
"TLS_enabled": mgmtURL.Scheme == "https",
}).Debug("starting management gRPC client")
mgmtConn, err := grpc.NewClient(mgmtURL.Host,
grpc.WithTransportCredentials(creds),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 20 * time.Second,
Timeout: 10 * time.Second,
PermitWithoutStream: true,
}),
proxygrpc.WithProxyToken(s.ProxyToken),
)
if err != nil {
return fmt.Errorf("could not create management connection: %w", err)
return err
}
defer func() {
if err := mgmtConn.Close(); err != nil {
@@ -205,54 +173,9 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
// to proxy over.
s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.ID, s.ProxyURL, s.WireguardPort, s.Logger, s, s.mgmtClient)
// When generating ACME certificates, start a challenge server.
tlsConfig := &tls.Config{}
if s.GenerateACMECertificates {
// Default to TLS-ALPN-01 challenge if not specified
if s.ACMEChallengeType == "" {
s.ACMEChallengeType = "tls-alpn-01"
}
s.Logger.WithFields(log.Fields{
"acme_server": s.ACMEDirectory,
"challenge_type": s.ACMEChallengeType,
}).Debug("ACME certificates enabled, configuring certificate manager")
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger, s.CertLockMethod)
// Only start HTTP server for HTTP-01 challenge type
if s.ACMEChallengeType == "http-01" {
s.http = &http.Server{
Addr: s.ACMEChallengeAddress,
Handler: s.acme.HTTPHandler(nil),
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueACME),
}
go func() {
if err := s.http.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
s.Logger.WithError(err).Error("ACME HTTP-01 challenge server failed")
}
}()
}
tlsConfig = s.acme.TLSConfig()
// ServerName needs to be set to allow for ACME to work correctly
// when using CNAME URLs to access the proxy.
tlsConfig.ServerName = s.ProxyURL
s.Logger.WithFields(log.Fields{
"ServerName": s.ProxyURL,
"challenge_type": s.ACMEChallengeType,
}).Debug("ACME certificate manager configured")
} else {
s.Logger.Debug("ACME certificates disabled, using static certificates with file watching")
certPath := filepath.Join(s.CertificateDirectory, s.CertificateFile)
keyPath := filepath.Join(s.CertificateDirectory, s.CertificateKeyFile)
certWatcher, err := certwatch.NewWatcher(certPath, keyPath, s.Logger)
if err != nil {
return fmt.Errorf("initialize certificate watcher: %w", err)
}
go certWatcher.Watch(ctx)
tlsConfig.GetCertificate = certWatcher.GetCertificate
tlsConfig, err := s.configureTLS(ctx)
if err != nil {
return err
}
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
@@ -356,6 +279,92 @@ const (
shutdownServiceTimeout = 5 * time.Second
)
// dialManagement creates the gRPC client connection to the management server
// addressed by s.ManagementAddress. TLS transport credentials are used when
// the address scheme is "https"; otherwise the connection is insecure. The
// proxy token is attached through proxygrpc.WithProxyToken.
func (s *Server) dialManagement() (*grpc.ClientConn, error) {
	mgmtURL, err := url.Parse(s.ManagementAddress)
	if err != nil {
		return nil, fmt.Errorf("parse management address: %w", err)
	}

	// Assume management TLS is enabled for gRPC as well if using HTTPS for the API.
	useTLS := mgmtURL.Scheme == "https"

	var creds credentials.TransportCredentials
	if useTLS {
		roots, poolErr := x509.SystemCertPool()
		if poolErr != nil || roots == nil {
			// Fall back to embedded CAs if no OS-provided ones are available.
			roots = embeddedroots.Get()
		}
		creds = credentials.NewTLS(&tls.Config{
			RootCAs: roots,
		})
	} else {
		creds = insecure.NewCredentials()
	}

	s.Logger.WithFields(log.Fields{
		"gRPC_address": mgmtURL.Host,
		"TLS_enabled":  useTLS,
	}).Debug("starting management gRPC client")

	keepaliveParams := keepalive.ClientParameters{
		Time:                20 * time.Second,
		Timeout:             10 * time.Second,
		PermitWithoutStream: true,
	}

	conn, err := grpc.NewClient(
		mgmtURL.Host,
		grpc.WithTransportCredentials(creds),
		grpc.WithKeepaliveParams(keepaliveParams),
		proxygrpc.WithProxyToken(s.ProxyToken),
	)
	if err != nil {
		return nil, fmt.Errorf("create management connection: %w", err)
	}
	return conn, nil
}
// configureTLS builds the TLS configuration for the proxy listener.
//
// When ACME certificate generation is disabled, a certificate watcher is
// created for the configured certificate and key files, started in the
// background with ctx, and used as the GetCertificate source.
//
// When ACME is enabled, the ACME manager is created and stored on s.acme, an
// HTTP server is started on s.ACMEChallengeAddress if the challenge type is
// "http-01", and the manager's TLS configuration is returned with ServerName
// set to s.ProxyURL. An empty s.ACMEChallengeType defaults to "tls-alpn-01".
//
// An error is returned only if the certificate watcher cannot be initialized.
func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) {
	// Upper bound on how long a client may take to send its request headers
	// to the HTTP-01 challenge listener. Without it, slow or stalled clients
	// can hold connections to the challenge server open indefinitely.
	const acmeReadHeaderTimeout = 10 * time.Second

	if !s.GenerateACMECertificates {
		s.Logger.Debug("ACME certificates disabled, using static certificates with file watching")
		certPath := filepath.Join(s.CertificateDirectory, s.CertificateFile)
		keyPath := filepath.Join(s.CertificateDirectory, s.CertificateKeyFile)

		certWatcher, err := certwatch.NewWatcher(certPath, keyPath, s.Logger)
		if err != nil {
			return nil, fmt.Errorf("initialize certificate watcher: %w", err)
		}
		go certWatcher.Watch(ctx)

		return &tls.Config{GetCertificate: certWatcher.GetCertificate}, nil
	}

	// Default to TLS-ALPN-01 challenge if not specified.
	if s.ACMEChallengeType == "" {
		s.ACMEChallengeType = "tls-alpn-01"
	}

	s.Logger.WithFields(log.Fields{
		"acme_server":    s.ACMEDirectory,
		"challenge_type": s.ACMEChallengeType,
	}).Debug("ACME certificates enabled, configuring certificate manager")

	s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger, s.CertLockMethod)

	// Only start the HTTP server for the HTTP-01 challenge type.
	if s.ACMEChallengeType == "http-01" {
		s.http = &http.Server{
			Addr:              s.ACMEChallengeAddress,
			Handler:           s.acme.HTTPHandler(nil),
			ErrorLog:          newHTTPServerLogger(s.Logger, logtagValueACME),
			ReadHeaderTimeout: acmeReadHeaderTimeout,
		}
		go func() {
			if err := s.http.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
				s.Logger.WithError(err).Error("ACME HTTP-01 challenge server failed")
			}
		}()
	}

	tlsConfig := s.acme.TLSConfig()
	// ServerName needs to be set to allow for ACME to work correctly
	// when using CNAME URLs to access the proxy.
	tlsConfig.ServerName = s.ProxyURL

	s.Logger.WithFields(log.Fields{
		"ServerName":     s.ProxyURL,
		"challenge_type": s.ACMEChallengeType,
	}).Debug("ACME certificate manager configured")

	return tlsConfig, nil
}
// gracefulShutdown performs a zero-downtime shutdown sequence. It marks the
// readiness probe as failing, waits for load balancer propagation, drains
// in-flight connections, and then stops all background services.
@@ -508,36 +517,7 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
return fmt.Errorf("receive msg: %w", err)
}
s.Logger.Debug("Received mapping update, starting processing")
// Process msg updates sequentially to avoid conflict, so block
// additional receiving until this processing is completed.
for _, mapping := range msg.GetMapping() {
s.Logger.WithFields(log.Fields{
"type": mapping.GetType(),
"domain": mapping.GetDomain(),
"path": mapping.GetPath(),
"id": mapping.GetId(),
}).Debug("Processing mapping update")
switch mapping.GetType() {
case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
if err := s.addMapping(ctx, mapping); err != nil {
// TODO: Retry this? Or maybe notify the management server that this mapping has failed?
s.Logger.WithFields(log.Fields{
"service_id": mapping.GetId(),
"domain": mapping.GetDomain(),
"error": err,
}).Error("Error adding new mapping, ignoring this mapping and continuing processing")
}
case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED:
if err := s.updateMapping(ctx, mapping); err != nil {
s.Logger.WithFields(log.Fields{
"service_id": mapping.GetId(),
"domain": mapping.GetDomain(),
}).Errorf("failed to update mapping: %v", err)
}
case proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED:
s.removeMapping(ctx, mapping)
}
}
s.processMappings(ctx, msg.GetMapping())
s.Logger.Debug("Processing mapping update completed")
if !*initialSyncDone && msg.GetInitialSyncComplete() {
@@ -551,6 +531,37 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
}
}
// processMappings applies each mapping update in order, dispatching on its
// update type to addMapping, updateMapping or removeMapping. Failures to add
// or update a mapping are logged and processing continues with the next one.
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
	for _, m := range mappings {
		s.Logger.WithFields(log.Fields{
			"type":   m.GetType(),
			"domain": m.GetDomain(),
			"path":   m.GetPath(),
			"id":     m.GetId(),
		}).Debug("Processing mapping update")

		switch m.GetType() {
		case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
			addErr := s.addMapping(ctx, m)
			if addErr == nil {
				continue
			}
			// TODO: Retry this? Or maybe notify the management server that this mapping has failed?
			s.Logger.WithFields(log.Fields{
				"service_id": m.GetId(),
				"domain":     m.GetDomain(),
				"error":      addErr,
			}).Error("Error adding new mapping, ignoring this mapping and continuing processing")

		case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED:
			updateErr := s.updateMapping(ctx, m)
			if updateErr == nil {
				continue
			}
			s.Logger.WithFields(log.Fields{
				"service_id": m.GetId(),
				"domain":     m.GetDomain(),
			}).Errorf("failed to update mapping: %v", updateErr)

		case proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED:
			s.removeMapping(ctx, m)
		}
	}
}
func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
d := domain.Domain(mapping.GetDomain())
accountID := types.AccountID(mapping.GetAccountId())