diff --git a/proxy/internal/roundtrip/netbird.go b/proxy/internal/roundtrip/netbird.go index e6a9f9ca7..402ff651e 100644 --- a/proxy/internal/roundtrip/netbird.go +++ b/proxy/internal/roundtrip/netbird.go @@ -1,6 +1,7 @@ package roundtrip import ( + "context" "fmt" "net/http" "sync" @@ -16,17 +17,17 @@ type NetBird struct { mgmtAddr string clientsMux sync.RWMutex - clients map[string]*http.Client + clients map[string]*embed.Client } func NewNetBird(mgmtAddr string) *NetBird { return &NetBird{ mgmtAddr: mgmtAddr, - clients: make(map[string]*http.Client), + clients: make(map[string]*embed.Client), } } -func (n *NetBird) AddPeer(domain, key string) error { +func (n *NetBird) AddPeer(ctx context.Context, domain, key string) error { client, err := embed.New(embed.Options{ DeviceName: deviceNamePrefix + domain, ManagementURL: n.mgmtAddr, @@ -35,16 +36,30 @@ func (n *NetBird) AddPeer(domain, key string) error { if err != nil { return fmt.Errorf("create netbird client: %w", err) } + if err := client.Start(ctx); err != nil { + return fmt.Errorf("start netbird client: %w", err) + } n.clientsMux.Lock() defer n.clientsMux.Unlock() - n.clients[domain] = client.NewHTTPClient() + n.clients[domain] = client return nil } -func (n *NetBird) RemovePeer(domain string) { +func (n *NetBird) RemovePeer(ctx context.Context, domain string) error { + n.clientsMux.RLock() + client, exists := n.clients[domain] + n.clientsMux.RUnlock() + if !exists { + // Mission failed successfully! + return nil + } + if err := client.Stop(ctx); err != nil { + return fmt.Errorf("stop netbird client: %w", err) + } n.clientsMux.Lock() defer n.clientsMux.Unlock() delete(n.clients, domain) + return nil } func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) { @@ -57,5 +72,5 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) { if !exists { return nil, fmt.Errorf("no peer connection found for host: %s", req.Host) } - return client.Do(req) + return client.NewHTTPClient().Do(req) } diff --git a/proxy/server.go b/proxy/server.go index 067716588..1b9a56fca 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -186,7 +186,7 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED: s.updateMapping(ctx, mapping) case proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED: - s.removeMapping(mapping) + s.removeMapping(ctx, mapping) } } } @@ -194,7 +194,7 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr } func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error { - if err := s.netbird.AddPeer(mapping.GetDomain(), mapping.GetSetupKey()); err != nil { + if err := s.netbird.AddPeer(ctx, mapping.GetDomain(), mapping.GetSetupKey()); err != nil { return fmt.Errorf("create peer for domain %q: %w", mapping.GetDomain(), err) } if s.acme != nil { @@ -245,8 +245,12 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping) s.proxy.AddMapping(s.protoToMapping(mapping)) } -func (s *Server) removeMapping(mapping *proto.ProxyMapping) { - s.netbird.RemovePeer(mapping.GetDomain()) +func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) { + if err := s.netbird.RemovePeer(ctx, mapping.GetDomain()); err != nil { + s.ErrorLog.ErrorContext(ctx, "Error removing NetBird peer connection for domain, continuing additional domain cleanup but peer connection may still exist", + "domain", mapping.GetDomain(), + "error", err) + } if s.acme != nil { s.acme.RemoveDomain(mapping.GetDomain()) }