Refactor roundtrip AddPeer to reduce cognitive complexity and line count

This commit is contained in:
Viktor Liu
2026-02-13 12:52:28 +01:00
parent cc5800f46d
commit f296956c6f
2 changed files with 150 additions and 75 deletions

View File

@@ -38,6 +38,11 @@ type domainInfo struct {
serviceID string
}
// domainNotification pairs a registered domain with the service ID stored for
// it, so the pair can be handed to the status notifier as one value.
type domainNotification struct {
// domain is the domain whose tunnel status is reported.
domain domain.Domain
// serviceID is the service ID recorded alongside the domain.
serviceID string
}
// clientEntry holds an embedded NetBird client and tracks which domains use it.
type clientEntry struct {
client *embed.Client
@@ -114,6 +119,30 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
return nil
}
entry, err := n.createClientEntry(ctx, accountID, d, authToken, serviceID)
if err != nil {
n.clientsMux.Unlock()
return err
}
n.clients[accountID] = entry
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).Info("created new client for account")
// Attempt to start the client in the background; if this fails we will
// retry on the first request via RoundTrip.
go n.runClientStartup(ctx, accountID, entry.client)
return nil
}
// createClientEntry generates a WireGuard keypair, authenticates with management,
// and creates an embedded NetBird client. Must be called with clientsMux held.
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) (*clientEntry, error) {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_id": serviceID,
@@ -121,8 +150,7 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
privateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
n.clientsMux.Unlock()
return fmt.Errorf("generate wireguard private key: %w", err)
return nil, fmt.Errorf("generate wireguard private key: %w", err)
}
publicKey := privateKey.PublicKey()
@@ -132,7 +160,6 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
"public_key": publicKey.String(),
}).Debug("authenticating new proxy peer with management")
// Authenticate with management using the one-time token and send public key
resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
ServiceId: serviceID,
AccountId: string(accountID),
@@ -141,16 +168,14 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
Cluster: n.proxyAddr,
})
if err != nil {
n.clientsMux.Unlock()
return fmt.Errorf("authenticate proxy peer with management: %w", err)
return nil, fmt.Errorf("authenticate proxy peer with management: %w", err)
}
if resp != nil && !resp.GetSuccess() {
n.clientsMux.Unlock()
errMsg := "unknown error"
if resp.ErrorMessage != nil {
errMsg = *resp.ErrorMessage
}
return fmt.Errorf("proxy peer authentication failed: %s", errMsg)
return nil, fmt.Errorf("proxy peer authentication failed: %s", errMsg)
}
n.logger.WithFields(log.Fields{
@@ -176,14 +201,13 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
WireguardPort: &n.wgPort,
})
if err != nil {
n.clientsMux.Unlock()
return fmt.Errorf("create netbird client: %w", err)
return nil, fmt.Errorf("create netbird client: %w", err)
}
// Create a transport using the client dialer. We do this instead of using
// the client's HTTPClient to avoid issues with request validation that do
// not work with reverse proxied requests.
entry = &clientEntry{
return &clientEntry{
client: client,
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
transport: &http.Transport{
@@ -196,75 +220,53 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
},
createdAt: time.Now(),
started: false,
}, nil
}
// runClientStartup starts the client and notifies registered domains on success.
func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) {
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := client.Start(startCtx); err != nil {
if errors.Is(err, context.DeadlineExceeded) {
n.logger.WithField("account_id", accountID).Warn("netbird client start timed out, will retry on first request")
} else {
n.logger.WithField("account_id", accountID).WithError(err).Error("failed to start netbird client")
}
return
}
// Mark client as started and collect domains to notify outside the lock.
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if exists {
entry.started = true
}
var domainsToNotify []domainNotification
if exists {
for dom, info := range entry.domains {
domainsToNotify = append(domainsToNotify, domainNotification{domain: dom, serviceID: info.serviceID})
}
}
n.clients[accountID] = entry
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).Info("created new client for account")
// Attempt to start the client in the background, if this fails
// then it is not ideal, but it isn't the end of the world because
// we will try to start the client again before we use it.
go func() {
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := client.Start(startCtx); err != nil {
if errors.Is(err, context.DeadlineExceeded) {
n.logger.WithFields(log.Fields{
"account_id": accountID,
}).Warn("netbird client start timed out, will retry on first request")
} else {
n.logger.WithFields(log.Fields{
"account_id": accountID,
}).WithError(err).Error("failed to start netbird client")
}
return
if n.statusNotifier == nil {
return
}
for _, dn := range domainsToNotify {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), dn.serviceID, string(dn.domain), true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": dn.domain,
}).WithError(err).Warn("failed to notify tunnel connection status")
} else {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": dn.domain,
}).Info("notified management about tunnel connection")
}
// Mark client as started and notify all registered domains
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if exists {
entry.started = true
}
// Copy domain info while holding lock
var domainsToNotify []struct {
domain domain.Domain
serviceID string
}
if exists {
for dom, info := range entry.domains {
domainsToNotify = append(domainsToNotify, struct {
domain domain.Domain
serviceID string
}{domain: dom, serviceID: info.serviceID})
}
}
n.clientsMux.Unlock()
// Notify all domains that they're connected
if n.statusNotifier != nil {
for _, domInfo := range domainsToNotify {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(domInfo.domain), true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": domInfo.domain,
}).WithError(err).Warn("failed to notify tunnel connection status")
} else {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": domInfo.domain,
}).Info("notified management about tunnel connection")
}
}
}
}()
return nil
}
}
// RemovePeer unregisters a domain from an account. The client is only stopped

View File

@@ -3,6 +3,7 @@ package roundtrip
import (
"context"
"net/http"
"sync"
"testing"
"github.com/stretchr/testify/assert"
@@ -20,6 +21,31 @@ func (m *mockMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxy
return &proto.CreateProxyPeerResponse{Success: true}, nil
}
// statusCall captures the arguments of a single NotifyStatus invocation.
type statusCall struct {
	accountID string
	serviceID string
	domain    string
	connected bool
}

// mockStatusNotifier is a test double that remembers every status
// notification it receives. The mutex guards statuses so the mock can be
// shared between the test goroutine and any goroutine calling NotifyStatus.
type mockStatusNotifier struct {
	mu       sync.Mutex
	statuses []statusCall
}

// NotifyStatus records the notification and always reports success.
func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID, serviceID, domain string, connected bool) error {
	recorded := statusCall{
		accountID: accountID,
		serviceID: serviceID,
		domain:    domain,
		connected: connected,
	}

	m.mu.Lock()
	m.statuses = append(m.statuses, recorded)
	m.mu.Unlock()

	return nil
}

// calls returns a copy of the notifications recorded so far; the caller may
// modify the returned slice without affecting the mock.
func (m *mockStatusNotifier) calls() []statusCall {
	m.mu.Lock()
	defer m.mu.Unlock()

	snapshot := make([]statusCall, len(m.statuses))
	copy(snapshot, m.statuses)
	return snapshot
}
// mockNetBird creates a NetBird instance for testing without actually connecting.
// It uses an invalid management URL to prevent real connections.
func mockNetBird() *NetBird {
@@ -253,3 +279,50 @@ func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) {
assert.Error(t, err)
assert.Contains(t, err.Error(), "no peer connection found for account")
}
func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
notifier := &mockStatusNotifier{}
nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{})
accountID := types.AccountID("account-1")
// Add first domain — creates a new client entry.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
require.NoError(t, err)
// Manually mark client as started to simulate background startup completing.
nb.clientsMux.Lock()
nb.clients[accountID].started = true
nb.clientsMux.Unlock()
// Add second domain — should notify immediately since client is already started.
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
require.NoError(t, err)
calls := notifier.calls()
require.Len(t, calls, 1)
assert.Equal(t, string(accountID), calls[0].accountID)
assert.Equal(t, "svc-2", calls[0].serviceID)
assert.Equal(t, "domain2.test", calls[0].domain)
assert.True(t, calls[0].connected)
}
func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
notifier := &mockStatusNotifier{}
nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{})
accountID := types.AccountID("account-1")
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
require.NoError(t, err)
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
require.NoError(t, err)
// Remove one domain — client stays, but disconnection notification fires.
err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
require.NoError(t, err)
assert.True(t, nb.HasClient(accountID))
calls := notifier.calls()
require.Len(t, calls, 1)
assert.Equal(t, "domain1.test", calls[0].domain)
assert.False(t, calls[0].connected)
}