From f296956c6f670e736e5c98ef470539b7d972ab5d Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 13 Feb 2026 12:52:28 +0100 Subject: [PATCH] Refactor roundtrip AddPeer to reduce cognitive complexity and line count --- proxy/internal/roundtrip/netbird.go | 152 ++++++++++++----------- proxy/internal/roundtrip/netbird_test.go | 73 +++++++++++ 2 files changed, 150 insertions(+), 75 deletions(-) diff --git a/proxy/internal/roundtrip/netbird.go b/proxy/internal/roundtrip/netbird.go index c32e6ee0c..00531a25c 100644 --- a/proxy/internal/roundtrip/netbird.go +++ b/proxy/internal/roundtrip/netbird.go @@ -38,6 +38,11 @@ type domainInfo struct { serviceID string } +type domainNotification struct { + domain domain.Domain + serviceID string +} + // clientEntry holds an embedded NetBird client and tracks which domains use it. type clientEntry struct { client *embed.Client @@ -114,6 +119,30 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma return nil } + entry, err := n.createClientEntry(ctx, accountID, d, authToken, serviceID) + if err != nil { + n.clientsMux.Unlock() + return err + } + + n.clients[accountID] = entry + n.clientsMux.Unlock() + + n.logger.WithFields(log.Fields{ + "account_id": accountID, + "domain": d, + }).Info("created new client for account") + + // Attempt to start the client in the background; if this fails we will + // retry on the first request via RoundTrip. + go n.runClientStartup(ctx, accountID, entry.client) + + return nil +} + +// createClientEntry generates a WireGuard keypair, authenticates with management, +// and creates an embedded NetBird client. Must be called with clientsMux held. +func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) (*clientEntry, error) { n.logger.WithFields(log.Fields{ "account_id": accountID, "service_id": serviceID, @@ -121,8 +150,7 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma privateKey, err := wgtypes.GeneratePrivateKey() if err != nil { - n.clientsMux.Unlock() - return fmt.Errorf("generate wireguard private key: %w", err) + return nil, fmt.Errorf("generate wireguard private key: %w", err) } publicKey := privateKey.PublicKey() @@ -132,7 +160,6 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma "public_key": publicKey.String(), }).Debug("authenticating new proxy peer with management") - // Authenticate with management using the one-time token and send public key resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{ ServiceId: serviceID, AccountId: string(accountID), @@ -141,16 +168,14 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma Cluster: n.proxyAddr, }) if err != nil { - n.clientsMux.Unlock() - return fmt.Errorf("authenticate proxy peer with management: %w", err) + return nil, fmt.Errorf("authenticate proxy peer with management: %w", err) } if resp != nil && !resp.GetSuccess() { - n.clientsMux.Unlock() errMsg := "unknown error" if resp.ErrorMessage != nil { errMsg = *resp.ErrorMessage } - return fmt.Errorf("proxy peer authentication failed: %s", errMsg) + return nil, fmt.Errorf("proxy peer authentication failed: %s", errMsg) } n.logger.WithFields(log.Fields{ @@ -176,14 +201,13 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma WireguardPort: &n.wgPort, }) if err != nil { - n.clientsMux.Unlock() - return fmt.Errorf("create netbird client: %w", err) + return nil, fmt.Errorf("create netbird client: %w", err) } // Create a transport using the client dialer. We do this instead of using // the client's HTTPClient to avoid issues with request validation that do // not work with reverse proxied requests. - entry = &clientEntry{ + return &clientEntry{ client: client, domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}}, transport: &http.Transport{ @@ -196,75 +220,53 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma }, createdAt: time.Now(), started: false, + }, nil +} + +// runClientStartup starts the client and notifies registered domains on success. +func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) { + startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if err := client.Start(startCtx); err != nil { + if errors.Is(err, context.DeadlineExceeded) { + n.logger.WithField("account_id", accountID).Warn("netbird client start timed out, will retry on first request") + } else { + n.logger.WithField("account_id", accountID).WithError(err).Error("failed to start netbird client") + } + return + } + + // Mark client as started and collect domains to notify outside the lock. + n.clientsMux.Lock() + entry, exists := n.clients[accountID] + if exists { + entry.started = true + } + var domainsToNotify []domainNotification + if exists { + for dom, info := range entry.domains { + domainsToNotify = append(domainsToNotify, domainNotification{domain: dom, serviceID: info.serviceID}) + } } - n.clients[accountID] = entry n.clientsMux.Unlock() - n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, - }).Info("created new client for account") - - // Attempt to start the client in the background, if this fails - // then it is not ideal, but it isn't the end of the world because - // we will try to start the client again before we use it. - go func() { - startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - if err := client.Start(startCtx); err != nil { - if errors.Is(err, context.DeadlineExceeded) { - n.logger.WithFields(log.Fields{ - "account_id": accountID, - }).Warn("netbird client start timed out, will retry on first request") - } else { - n.logger.WithFields(log.Fields{ - "account_id": accountID, - }).WithError(err).Error("failed to start netbird client") - } - return + if n.statusNotifier == nil { + return + } + for _, dn := range domainsToNotify { + if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), dn.serviceID, string(dn.domain), true); err != nil { + n.logger.WithFields(log.Fields{ + "account_id": accountID, + "domain": dn.domain, + }).WithError(err).Warn("failed to notify tunnel connection status") + } else { + n.logger.WithFields(log.Fields{ + "account_id": accountID, + "domain": dn.domain, + }).Info("notified management about tunnel connection") } - - // Mark client as started and notify all registered domains - n.clientsMux.Lock() - entry, exists := n.clients[accountID] - if exists { - entry.started = true - } - // Copy domain info while holding lock - var domainsToNotify []struct { - domain domain.Domain - serviceID string - } - if exists { - for dom, info := range entry.domains { - domainsToNotify = append(domainsToNotify, struct { - domain domain.Domain - serviceID string - }{domain: dom, serviceID: info.serviceID}) - } - } - n.clientsMux.Unlock() - - // Notify all domains that they're connected - if n.statusNotifier != nil { - for _, domInfo := range domainsToNotify { - if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(domInfo.domain), true); err != nil { - n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": domInfo.domain, - }).WithError(err).Warn("failed to notify tunnel connection status") - } else { - n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": domInfo.domain, - }).Info("notified management about tunnel connection") - } - } - } - }() - - return nil + } } // RemovePeer unregisters a domain from an account. The client is only stopped diff --git a/proxy/internal/roundtrip/netbird_test.go b/proxy/internal/roundtrip/netbird_test.go index fb7e7fa01..3e76af9da 100644 --- a/proxy/internal/roundtrip/netbird_test.go +++ b/proxy/internal/roundtrip/netbird_test.go @@ -3,6 +3,7 @@ package roundtrip import ( "context" "net/http" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -20,6 +21,31 @@ func (m *mockMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxy return &proto.CreateProxyPeerResponse{Success: true}, nil } +type mockStatusNotifier struct { + mu sync.Mutex + statuses []statusCall +} + +type statusCall struct { + accountID string + serviceID string + domain string + connected bool +} + +func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID, serviceID, domain string, connected bool) error { + m.mu.Lock() + defer m.mu.Unlock() + m.statuses = append(m.statuses, statusCall{accountID, serviceID, domain, connected}) + return nil +} + +func (m *mockStatusNotifier) calls() []statusCall { + m.mu.Lock() + defer m.mu.Unlock() + return append([]statusCall{}, m.statuses...) +} + // mockNetBird creates a NetBird instance for testing without actually connecting. // It uses an invalid management URL to prevent real connections. func mockNetBird() *NetBird { @@ -253,3 +279,50 @@ func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "no peer connection found for account") } + +func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) { + notifier := &mockStatusNotifier{} + nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{}) + accountID := types.AccountID("account-1") + + // Add first domain — creates a new client entry. + err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1") + require.NoError(t, err) + + // Manually mark client as started to simulate background startup completing. + nb.clientsMux.Lock() + nb.clients[accountID].started = true + nb.clientsMux.Unlock() + + // Add second domain — should notify immediately since client is already started. + err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2") + require.NoError(t, err) + + calls := notifier.calls() + require.Len(t, calls, 1) + assert.Equal(t, string(accountID), calls[0].accountID) + assert.Equal(t, "svc-2", calls[0].serviceID) + assert.Equal(t, "domain2.test", calls[0].domain) + assert.True(t, calls[0].connected) +} + +func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) { + notifier := &mockStatusNotifier{} + nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{}) + accountID := types.AccountID("account-1") + + err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1") + require.NoError(t, err) + err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2") + require.NoError(t, err) + + // Remove one domain — client stays, but disconnection notification fires. + err = nb.RemovePeer(context.Background(), accountID, "domain1.test") + require.NoError(t, err) + assert.True(t, nb.HasClient(accountID)) + + calls := notifier.calls() + require.Len(t, calls, 1) + assert.Equal(t, "domain1.test", calls[0].domain) + assert.False(t, calls[0].connected) +}