Merge branch 'prototype/reverse-proxy' into fix/log-formatter

# Conflicts:
#	proxy/server.go
This commit is contained in:
Zoltán Papp
2026-02-13 13:21:50 +01:00
34 changed files with 5542 additions and 922 deletions

View File

@@ -84,7 +84,7 @@ func (l *flockLocker) Lock(ctx context.Context, domain string) (func(), error) {
// nil lockFile means locking is not supported (non-unix).
if lockFile == nil {
return func() {}, nil
return func() { /* no-op: locking unsupported on this platform */ }, nil
}
return func() {
@@ -98,5 +98,5 @@ type noopLocker struct{}
// Lock is a no-op that always succeeds immediately.
func (noopLocker) Lock(context.Context, string) (func(), error) {
return func() {}, nil
return func() { /* no-op: locker disabled */ }, nil
}

View File

@@ -90,10 +90,8 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
if err != nil {
host = r.Host
}
mw.domainsMux.RLock()
config, exists := mw.domains[host]
mw.domainsMux.RUnlock()
config, exists := mw.getDomainConfig(host)
mw.logger.Debugf("checking authentication for host: %s, exists: %t", host, exists)
// Domains that are not configured here or have no authentication schemes applied should simply pass through.
@@ -103,115 +101,160 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
}
// Set account and service IDs in captured data for access logging.
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetAccountId(types.AccountID(config.AccountID))
cd.SetServiceId(config.ServiceID)
}
setCapturedIDs(r, config)
// Check for error from OAuth callback (e.g., access denied)
if errCode := r.URL.Query().Get("error"); errCode != "" {
var requestID string
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetAuthMethod(auth.MethodOIDC.String())
requestID = cd.GetRequestID()
}
errDesc := r.URL.Query().Get("error_description")
if errDesc == "" {
errDesc = "An error occurred during authentication"
}
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID)
if mw.handleOAuthCallbackError(w, r) {
return
}
// Check for an existing session cookie (contains JWT)
if cookie, err := r.Cookie(auth.SessionCookieName); err == nil {
if userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey); err == nil {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetUserID(userID)
cd.SetAuthMethod(method)
}
next.ServeHTTP(w, r)
return
}
if mw.forwardWithSessionCookie(w, r, host, config, next) {
return
}
// Try to authenticate with each scheme.
methods := make(map[string]string)
var attemptedMethod string
for _, scheme := range config.Schemes {
token, promptData, err := scheme.Authenticate(r)
if err != nil {
mw.logger.WithField("scheme", scheme.Type().String()).Warnf("authentication infrastructure error: %v", err)
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
}
http.Error(w, "authentication service unavailable", http.StatusBadGateway)
return
mw.authenticateWithSchemes(w, r, host, config)
})
}
func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) {
mw.domainsMux.RLock()
defer mw.domainsMux.RUnlock()
config, exists := mw.domains[host]
return config, exists
}
func setCapturedIDs(r *http.Request, config DomainConfig) {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetAccountId(types.AccountID(config.AccountID))
cd.SetServiceId(config.ServiceID)
}
}
// handleOAuthCallbackError checks for error query parameters from an OAuth
// callback and renders the access denied page if present.
func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Request) bool {
errCode := r.URL.Query().Get("error")
if errCode == "" {
return false
}
var requestID string
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetAuthMethod(auth.MethodOIDC.String())
requestID = cd.GetRequestID()
}
errDesc := r.URL.Query().Get("error_description")
if errDesc == "" {
errDesc = "An error occurred during authentication"
}
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID)
return true
}
// forwardWithSessionCookie checks for a valid session cookie and, if found,
// sets the user identity on the request context and forwards to the next handler.
func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
cookie, err := r.Cookie(auth.SessionCookieName)
if err != nil {
return false
}
userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey)
if err != nil {
return false
}
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetUserID(userID)
cd.SetAuthMethod(method)
}
next.ServeHTTP(w, r)
return true
}
// authenticateWithSchemes tries each configured auth scheme in order.
// On success it sets a session cookie and redirects; on failure it renders the login page.
func (mw *Middleware) authenticateWithSchemes(w http.ResponseWriter, r *http.Request, host string, config DomainConfig) {
methods := make(map[string]string)
var attemptedMethod string
for _, scheme := range config.Schemes {
token, promptData, err := scheme.Authenticate(r)
if err != nil {
mw.logger.WithField("scheme", scheme.Type().String()).Warnf("authentication infrastructure error: %v", err)
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
}
// Track if credentials were submitted but auth failed
if token == "" && wasCredentialSubmitted(r, scheme.Type()) {
attemptedMethod = scheme.Type().String()
}
if token != "" {
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, scheme.Type())
if err != nil {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetAuthMethod(scheme.Type().String())
}
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if !result.Valid {
var requestID string
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetAuthMethod(scheme.Type().String())
requestID = cd.GetRequestID()
}
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", "You are not authorized to access this service", requestID)
return
}
expiration := config.SessionExpiration
if expiration == 0 {
expiration = auth.DefaultSessionExpiry
}
http.SetCookie(w, &http.Cookie{
Name: auth.SessionCookieName,
Value: token,
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
MaxAge: int(expiration.Seconds()),
})
// Redirect instead of forwarding the auth POST to the backend.
// The browser will follow with a GET carrying the new session cookie.
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetAuthMethod(scheme.Type().String())
}
redirectURL := stripSessionTokenParam(r.URL)
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return
}
methods[scheme.Type().String()] = promptData
http.Error(w, "authentication service unavailable", http.StatusBadGateway)
return
}
// Track if credentials were submitted but auth failed
if token == "" && wasCredentialSubmitted(r, scheme.Type()) {
attemptedMethod = scheme.Type().String()
}
if token != "" {
mw.handleAuthenticatedToken(w, r, host, token, config, scheme)
return
}
methods[scheme.Type().String()] = promptData
}
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
if attemptedMethod != "" {
cd.SetAuthMethod(attemptedMethod)
}
}
web.ServeHTTP(w, r, map[string]any{"methods": methods}, http.StatusUnauthorized)
}
// handleAuthenticatedToken validates the token, handles denied access, and on
// success sets a session cookie and redirects to the original URL.
func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Request, host, token string, config DomainConfig, scheme Scheme) {
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, scheme.Type())
if err != nil {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
if attemptedMethod != "" {
cd.SetAuthMethod(attemptedMethod)
}
cd.SetAuthMethod(scheme.Type().String())
}
web.ServeHTTP(w, r, map[string]any{"methods": methods}, http.StatusUnauthorized)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if !result.Valid {
var requestID string
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetAuthMethod(scheme.Type().String())
requestID = cd.GetRequestID()
}
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", "You are not authorized to access this service", requestID)
return
}
expiration := config.SessionExpiration
if expiration == 0 {
expiration = auth.DefaultSessionExpiry
}
http.SetCookie(w, &http.Cookie{
Name: auth.SessionCookieName,
Value: token,
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
MaxAge: int(expiration.Seconds()),
})
// Redirect instead of forwarding the auth POST to the backend.
// The browser will follow with a GET carrying the new session cookie.
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetAuthMethod(scheme.Type().String())
}
redirectURL := stripSessionTokenParam(r.URL)
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
// wasCredentialSubmitted checks if credentials were submitted for the given auth method.

View File

@@ -83,6 +83,10 @@ func (c *Client) printHealth(data map[string]any) {
}
}
c.printHealthClients(data)
}
func (c *Client) printHealthClients(data map[string]any) {
clients, ok := data["clients"].(map[string]any)
if !ok || len(clients) == 0 {
return

View File

@@ -0,0 +1,71 @@
package debug
import (
"bytes"
"testing"
"github.com/stretchr/testify/assert"
)
func TestPrintHealth_WithCertsAndClients(t *testing.T) {
var buf bytes.Buffer
c := NewClient("localhost:8444", false, &buf)
data := map[string]any{
"status": "ok",
"uptime": "1h30m",
"management_connected": true,
"all_clients_healthy": true,
"certs_total": float64(3),
"certs_ready": float64(2),
"certs_pending": float64(1),
"certs_failed": float64(0),
"certs_ready_domains": []any{"a.example.com", "b.example.com"},
"certs_pending_domains": []any{"c.example.com"},
"clients": map[string]any{
"acc-1": map[string]any{
"healthy": true,
"management_connected": true,
"signal_connected": true,
"relays_connected": float64(1),
"relays_total": float64(2),
"peers_connected": float64(3),
"peers_total": float64(5),
"peers_p2p": float64(2),
"peers_relayed": float64(1),
"peers_degraded": float64(0),
},
},
}
c.printHealth(data)
out := buf.String()
assert.Contains(t, out, "Status: ok")
assert.Contains(t, out, "Uptime: 1h30m")
assert.Contains(t, out, "yes") // management_connected
assert.Contains(t, out, "2 ready, 1 pending, 0 failed (3 total)")
assert.Contains(t, out, "a.example.com")
assert.Contains(t, out, "c.example.com")
assert.Contains(t, out, "acc-1")
}
func TestPrintHealth_Minimal(t *testing.T) {
var buf bytes.Buffer
c := NewClient("localhost:8444", false, &buf)
data := map[string]any{
"status": "ok",
"uptime": "5m",
"management_connected": false,
"all_clients_healthy": false,
}
c.printHealth(data)
out := buf.String()
assert.Contains(t, out, "Status: ok")
assert.Contains(t, out, "Uptime: 5m")
assert.NotContains(t, out, "Certificates")
assert.NotContains(t, out, "ACCOUNT ID")
}

View File

@@ -323,7 +323,7 @@ func NewServer(addr string, checker *Checker, logger *log.Logger, metricsHandler
if metricsHandler != nil {
mux := http.NewServeMux()
mux.Handle("/metrics", metricsHandler)
mux.Handle("/", checker.Handler())
mux.Handle("/", handler)
handler = mux
}

View File

@@ -404,3 +404,70 @@ func TestChecker_Handler_Full(t *testing.T) {
// Clients may be empty map when no clients exist.
assert.Empty(t, resp.Clients)
}
func TestChecker_SetShuttingDown(t *testing.T) {
checker := NewChecker(nil, &mockClientProvider{})
checker.SetManagementConnected(true)
assert.True(t, checker.ReadinessProbe(), "should be ready before shutdown")
checker.SetShuttingDown()
assert.False(t, checker.ReadinessProbe(), "should not be ready after shutdown")
}
func TestChecker_Handler_Readiness_ShuttingDown(t *testing.T) {
checker := NewChecker(nil, &mockClientProvider{})
checker.SetManagementConnected(true)
checker.SetShuttingDown()
handler := checker.Handler()
req := httptest.NewRequest(http.MethodGet, "/healthz/ready", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
var resp ProbeResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
assert.Equal(t, "fail", resp.Status)
}
func TestNewServer_WithMetricsHandler(t *testing.T) {
checker := NewChecker(nil, &mockClientProvider{})
checker.SetManagementConnected(true)
metricsHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("metrics"))
})
srv := NewServer(":0", checker, nil, metricsHandler)
require.NotNil(t, srv)
// Verify health endpoint still works through the mux.
req := httptest.NewRequest(http.MethodGet, "/healthz/live", nil)
rec := httptest.NewRecorder()
srv.server.Handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// Verify metrics endpoint is mounted.
req = httptest.NewRequest(http.MethodGet, "/metrics", nil)
rec = httptest.NewRecorder()
srv.server.Handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "metrics", rec.Body.String())
}
func TestNewServer_WithoutMetricsHandler(t *testing.T) {
checker := NewChecker(nil, &mockClientProvider{})
checker.SetManagementConnected(true)
srv := NewServer(":0", checker, nil, nil)
require.NotNil(t, srv)
req := httptest.NewRequest(http.MethodGet, "/healthz/live", nil)
rec := httptest.NewRecorder()
srv.server.Handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
}

View File

@@ -320,7 +320,8 @@ func getRequestID(r *http.Request) string {
// status code, and component status based on the error type.
func classifyProxyError(err error) (title, message string, code int, status web.ErrorStatus) {
switch {
case errors.Is(err, context.DeadlineExceeded):
case errors.Is(err, context.DeadlineExceeded),
isNetTimeout(err):
return "Request Timeout",
"The request timed out while trying to reach the service. Please refresh the page and try again.",
http.StatusGatewayTimeout,
@@ -356,12 +357,6 @@ func classifyProxyError(err error) (title, message string, code int, status web.
"The connection to the peer could not be established. Please ensure the peer is running and connected to the NetBird network.",
http.StatusBadGateway,
web.ErrorStatus{Proxy: true, Destination: false}
case isNetTimeout(err):
return "Request Timeout",
"The request timed out while trying to reach the service. Please refresh the page and try again.",
http.StatusGatewayTimeout,
web.ErrorStatus{Proxy: true, Destination: false}
}
return "Connection Error",

View File

@@ -38,6 +38,11 @@ type domainInfo struct {
serviceID string
}
type domainNotification struct {
domain domain.Domain
serviceID string
}
// clientEntry holds an embedded NetBird client and tracks which domains use it.
type clientEntry struct {
client *embed.Client
@@ -114,6 +119,30 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
return nil
}
entry, err := n.createClientEntry(ctx, accountID, d, authToken, serviceID)
if err != nil {
n.clientsMux.Unlock()
return err
}
n.clients[accountID] = entry
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).Info("created new client for account")
// Attempt to start the client in the background; if this fails we will
// retry on the first request via RoundTrip.
go n.runClientStartup(ctx, accountID, entry.client)
return nil
}
// createClientEntry generates a WireGuard keypair, authenticates with management,
// and creates an embedded NetBird client. Must be called with clientsMux held.
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) (*clientEntry, error) {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_id": serviceID,
@@ -121,8 +150,7 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
privateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
n.clientsMux.Unlock()
return fmt.Errorf("generate wireguard private key: %w", err)
return nil, fmt.Errorf("generate wireguard private key: %w", err)
}
publicKey := privateKey.PublicKey()
@@ -132,7 +160,6 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
"public_key": publicKey.String(),
}).Debug("authenticating new proxy peer with management")
// Authenticate with management using the one-time token and send public key
resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
ServiceId: serviceID,
AccountId: string(accountID),
@@ -141,16 +168,14 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
Cluster: n.proxyAddr,
})
if err != nil {
n.clientsMux.Unlock()
return fmt.Errorf("authenticate proxy peer with management: %w", err)
return nil, fmt.Errorf("authenticate proxy peer with management: %w", err)
}
if resp != nil && !resp.GetSuccess() {
n.clientsMux.Unlock()
errMsg := "unknown error"
if resp.ErrorMessage != nil {
errMsg = *resp.ErrorMessage
}
return fmt.Errorf("proxy peer authentication failed: %s", errMsg)
return nil, fmt.Errorf("proxy peer authentication failed: %s", errMsg)
}
n.logger.WithFields(log.Fields{
@@ -176,14 +201,13 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
WireguardPort: &n.wgPort,
})
if err != nil {
n.clientsMux.Unlock()
return fmt.Errorf("create netbird client: %w", err)
return nil, fmt.Errorf("create netbird client: %w", err)
}
// Create a transport using the client dialer. We do this instead of using
// the client's HTTPClient to avoid issues with request validation that do
// not work with reverse proxied requests.
entry = &clientEntry{
return &clientEntry{
client: client,
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
transport: &http.Transport{
@@ -196,75 +220,53 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
},
createdAt: time.Now(),
started: false,
}, nil
}
// runClientStartup starts the client and notifies registered domains on success.
func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) {
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := client.Start(startCtx); err != nil {
if errors.Is(err, context.DeadlineExceeded) {
n.logger.WithField("account_id", accountID).Warn("netbird client start timed out, will retry on first request")
} else {
n.logger.WithField("account_id", accountID).WithError(err).Error("failed to start netbird client")
}
return
}
// Mark client as started and collect domains to notify outside the lock.
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if exists {
entry.started = true
}
var domainsToNotify []domainNotification
if exists {
for dom, info := range entry.domains {
domainsToNotify = append(domainsToNotify, domainNotification{domain: dom, serviceID: info.serviceID})
}
}
n.clients[accountID] = entry
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).Info("created new client for account")
// Attempt to start the client in the background, if this fails
// then it is not ideal, but it isn't the end of the world because
// we will try to start the client again before we use it.
go func() {
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := client.Start(startCtx); err != nil {
if errors.Is(err, context.DeadlineExceeded) {
n.logger.WithFields(log.Fields{
"account_id": accountID,
}).Warn("netbird client start timed out, will retry on first request")
} else {
n.logger.WithFields(log.Fields{
"account_id": accountID,
}).WithError(err).Error("failed to start netbird client")
}
return
if n.statusNotifier == nil {
return
}
for _, dn := range domainsToNotify {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), dn.serviceID, string(dn.domain), true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": dn.domain,
}).WithError(err).Warn("failed to notify tunnel connection status")
} else {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": dn.domain,
}).Info("notified management about tunnel connection")
}
// Mark client as started and notify all registered domains
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if exists {
entry.started = true
}
// Copy domain info while holding lock
var domainsToNotify []struct {
domain domain.Domain
serviceID string
}
if exists {
for dom, info := range entry.domains {
domainsToNotify = append(domainsToNotify, struct {
domain domain.Domain
serviceID string
}{domain: dom, serviceID: info.serviceID})
}
}
n.clientsMux.Unlock()
// Notify all domains that they're connected
if n.statusNotifier != nil {
for _, domInfo := range domainsToNotify {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(domInfo.domain), true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": domInfo.domain,
}).WithError(err).Warn("failed to notify tunnel connection status")
} else {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": domInfo.domain,
}).Info("notified management about tunnel connection")
}
}
}
}()
return nil
}
}
// RemovePeer unregisters a domain from an account. The client is only stopped

View File

@@ -3,6 +3,7 @@ package roundtrip
import (
"context"
"net/http"
"sync"
"testing"
"github.com/stretchr/testify/assert"
@@ -20,6 +21,31 @@ func (m *mockMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxy
return &proto.CreateProxyPeerResponse{Success: true}, nil
}
type mockStatusNotifier struct {
mu sync.Mutex
statuses []statusCall
}
type statusCall struct {
accountID string
serviceID string
domain string
connected bool
}
func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID, serviceID, domain string, connected bool) error {
m.mu.Lock()
defer m.mu.Unlock()
m.statuses = append(m.statuses, statusCall{accountID, serviceID, domain, connected})
return nil
}
func (m *mockStatusNotifier) calls() []statusCall {
m.mu.Lock()
defer m.mu.Unlock()
return append([]statusCall{}, m.statuses...)
}
// mockNetBird creates a NetBird instance for testing without actually connecting.
// It uses an invalid management URL to prevent real connections.
func mockNetBird() *NetBird {
@@ -253,3 +279,50 @@ func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) {
assert.Error(t, err)
assert.Contains(t, err.Error(), "no peer connection found for account")
}
func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
notifier := &mockStatusNotifier{}
nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{})
accountID := types.AccountID("account-1")
// Add first domain — creates a new client entry.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
require.NoError(t, err)
// Manually mark client as started to simulate background startup completing.
nb.clientsMux.Lock()
nb.clients[accountID].started = true
nb.clientsMux.Unlock()
// Add second domain — should notify immediately since client is already started.
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
require.NoError(t, err)
calls := notifier.calls()
require.Len(t, calls, 1)
assert.Equal(t, string(accountID), calls[0].accountID)
assert.Equal(t, "svc-2", calls[0].serviceID)
assert.Equal(t, "domain2.test", calls[0].domain)
assert.True(t, calls[0].connected)
}
func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
notifier := &mockStatusNotifier{}
nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{})
accountID := types.AccountID("account-1")
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
require.NoError(t, err)
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
require.NoError(t, err)
// Remove one domain — client stays, but disconnection notification fires.
err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
require.NoError(t, err)
assert.True(t, nb.HasClient(accountID))
calls := notifier.calls()
require.Len(t, calls, 1)
assert.Equal(t, "domain1.test", calls[0].domain)
assert.False(t, calls[0].connected)
}