start and stop netbird embedded clients in proxy

This commit is contained in:
Alisdair MacLeod
2026-01-27 08:33:44 +00:00
parent b0b60b938a
commit 703ef29199
2 changed files with 29 additions and 10 deletions

View File

@@ -1,6 +1,7 @@
package roundtrip
import (
"context"
"fmt"
"net/http"
"sync"
@@ -16,17 +17,17 @@ type NetBird struct {
mgmtAddr string
clientsMux sync.RWMutex
clients map[string]*http.Client
clients map[string]*embed.Client
}
func NewNetBird(mgmtAddr string) *NetBird {
return &NetBird{
mgmtAddr: mgmtAddr,
clients: make(map[string]*http.Client),
clients: make(map[string]*embed.Client),
}
}
func (n *NetBird) AddPeer(domain, key string) error {
func (n *NetBird) AddPeer(ctx context.Context, domain, key string) error {
client, err := embed.New(embed.Options{
DeviceName: deviceNamePrefix + domain,
ManagementURL: n.mgmtAddr,
@@ -35,16 +36,30 @@ func (n *NetBird) AddPeer(domain, key string) error {
if err != nil {
return fmt.Errorf("create netbird client: %w", err)
}
if err := client.Start(ctx); err != nil {
return fmt.Errorf("start netbird client: %w", err)
}
n.clientsMux.Lock()
defer n.clientsMux.Unlock()
n.clients[domain] = client.NewHTTPClient()
n.clients[domain] = client
return nil
}
func (n *NetBird) RemovePeer(domain string) {
func (n *NetBird) RemovePeer(ctx context.Context, domain string) error {
n.clientsMux.RLock()
client, exists := n.clients[domain]
n.clientsMux.RUnlock()
if !exists {
// Mission failed successfully!
return nil
}
if err := client.Stop(ctx); err != nil {
return fmt.Errorf("stop netbird client: %w", err)
}
n.clientsMux.Lock()
defer n.clientsMux.Unlock()
delete(n.clients, domain)
return nil
}
func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
@@ -57,5 +72,5 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
if !exists {
return nil, fmt.Errorf("no peer connection found for host: %s", req.Host)
}
return client.Do(req)
return client.NewHTTPClient().Do(req)
}

View File

@@ -186,7 +186,7 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED:
s.updateMapping(ctx, mapping)
case proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED:
s.removeMapping(mapping)
s.removeMapping(ctx, mapping)
}
}
}
@@ -194,7 +194,7 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
}
func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
if err := s.netbird.AddPeer(mapping.GetDomain(), mapping.GetSetupKey()); err != nil {
if err := s.netbird.AddPeer(ctx, mapping.GetDomain(), mapping.GetSetupKey()); err != nil {
return fmt.Errorf("create peer for domain %q: %w", mapping.GetDomain(), err)
}
if s.acme != nil {
@@ -245,8 +245,12 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
s.proxy.AddMapping(s.protoToMapping(mapping))
}
func (s *Server) removeMapping(mapping *proto.ProxyMapping) {
s.netbird.RemovePeer(mapping.GetDomain())
func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) {
if err := s.netbird.RemovePeer(ctx, mapping.GetDomain()); err != nil {
s.ErrorLog.ErrorContext(ctx, "Error removing NetBird peer connection for domain, continuing additional domain cleanup but peer connection may still exist",
"domain", mapping.GetDomain(),
"error", err)
}
if s.acme != nil {
s.acme.RemoveDomain(mapping.GetDomain())
}