mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 00:06:38 +00:00
add status confirmation for certs and tunnel creation
This commit is contained in:
@@ -5,20 +5,34 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/acme"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
)
|
||||
|
||||
type certificateNotifier interface {
|
||||
NotifyCertificateIssued(ctx context.Context, accountID, reverseProxyID, domain string) error
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
*autocert.Manager
|
||||
|
||||
domainsMux sync.RWMutex
|
||||
domains map[string]struct{}
|
||||
domains map[string]struct {
|
||||
accountID string
|
||||
reverseProxyID string
|
||||
}
|
||||
|
||||
certNotifier certificateNotifier
|
||||
}
|
||||
|
||||
func NewManager(certDir, acmeURL string) *Manager {
|
||||
func NewManager(certDir, acmeURL string, notifier certificateNotifier) *Manager {
|
||||
mgr := &Manager{
|
||||
domains: make(map[string]struct{}),
|
||||
domains: make(map[string]struct {
|
||||
accountID string
|
||||
reverseProxyID string
|
||||
}),
|
||||
certNotifier: notifier,
|
||||
}
|
||||
mgr.Manager = &autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
@@ -31,19 +45,33 @@ func NewManager(certDir, acmeURL string) *Manager {
|
||||
return mgr
|
||||
}
|
||||
|
||||
func (mgr *Manager) hostPolicy(_ context.Context, domain string) error {
|
||||
func (mgr *Manager) hostPolicy(ctx context.Context, domain string) error {
|
||||
mgr.domainsMux.RLock()
|
||||
defer mgr.domainsMux.RUnlock()
|
||||
if _, exists := mgr.domains[domain]; exists {
|
||||
return nil
|
||||
info, exists := mgr.domains[domain]
|
||||
mgr.domainsMux.RUnlock()
|
||||
if !exists {
|
||||
return fmt.Errorf("unknown domain %q", domain)
|
||||
}
|
||||
return fmt.Errorf("unknown domain %q", domain)
|
||||
|
||||
if mgr.certNotifier != nil {
|
||||
if err := mgr.certNotifier.NotifyCertificateIssued(ctx, info.accountID, info.reverseProxyID, domain); err != nil {
|
||||
log.Warnf("failed to notify certificate issued for domain %q: %v", domain, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mgr *Manager) AddDomain(domain string) {
|
||||
func (mgr *Manager) AddDomain(domain, accountID, reverseProxyID string) {
|
||||
mgr.domainsMux.Lock()
|
||||
defer mgr.domainsMux.Unlock()
|
||||
mgr.domains[domain] = struct{}{}
|
||||
mgr.domains[domain] = struct {
|
||||
accountID string
|
||||
reverseProxyID string
|
||||
}{
|
||||
accountID: accountID,
|
||||
reverseProxyID: reverseProxyID,
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) RemoveDomain(domain string) {
|
||||
|
||||
@@ -17,6 +17,10 @@ import (
|
||||
|
||||
const deviceNamePrefix = "ingress-"
|
||||
|
||||
type statusNotifier interface {
|
||||
NotifyStatus(ctx context.Context, accountID, reverseProxyID, domain string, connected bool) error
|
||||
}
|
||||
|
||||
// NetBird provides an http.RoundTripper implementation
|
||||
// backed by underlying NetBird connections.
|
||||
type NetBird struct {
|
||||
@@ -25,20 +29,23 @@ type NetBird struct {
|
||||
|
||||
clientsMux sync.RWMutex
|
||||
clients map[string]*embed.Client
|
||||
|
||||
statusNotifier statusNotifier
|
||||
}
|
||||
|
||||
func NewNetBird(mgmtAddr string, logger *log.Logger) *NetBird {
|
||||
func NewNetBird(mgmtAddr string, logger *log.Logger, notifier statusNotifier) *NetBird {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
return &NetBird{
|
||||
mgmtAddr: mgmtAddr,
|
||||
logger: logger,
|
||||
clients: make(map[string]*embed.Client),
|
||||
mgmtAddr: mgmtAddr,
|
||||
logger: logger,
|
||||
clients: make(map[string]*embed.Client),
|
||||
statusNotifier: notifier,
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NetBird) AddPeer(ctx context.Context, domain, key string) error {
|
||||
func (n *NetBird) AddPeer(ctx context.Context, domain, key, accountID, reverseProxyID string) error {
|
||||
client, err := embed.New(embed.Options{
|
||||
DeviceName: deviceNamePrefix + domain,
|
||||
ManagementURL: n.mgmtAddr,
|
||||
@@ -64,6 +71,16 @@ func (n *NetBird) AddPeer(ctx context.Context, domain, key string) error {
|
||||
return
|
||||
case err != nil:
|
||||
n.logger.WithField("domain", domain).WithError(err).Error("Unable to start netbird client, will try again later.")
|
||||
return
|
||||
}
|
||||
|
||||
// Notify management that tunnel is now active
|
||||
if n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, accountID, reverseProxyID, domain, true); err != nil {
|
||||
n.logger.WithField("domain", domain).WithError(err).Warn("Failed to notify management about tunnel connection")
|
||||
} else {
|
||||
n.logger.WithField("domain", domain).Info("Successfully notified management about tunnel connection")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -73,7 +90,7 @@ func (n *NetBird) AddPeer(ctx context.Context, domain, key string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *NetBird) RemovePeer(ctx context.Context, domain string) error {
|
||||
func (n *NetBird) RemovePeer(ctx context.Context, domain, accountID, reverseProxyID string) error {
|
||||
n.clientsMux.RLock()
|
||||
client, exists := n.clients[domain]
|
||||
n.clientsMux.RUnlock()
|
||||
@@ -84,6 +101,16 @@ func (n *NetBird) RemovePeer(ctx context.Context, domain string) error {
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
return fmt.Errorf("stop netbird client: %w", err)
|
||||
}
|
||||
|
||||
// Notify management that tunnel is disconnected
|
||||
if n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, accountID, reverseProxyID, domain, false); err != nil {
|
||||
n.logger.WithField("domain", domain).WithError(err).Warn("Failed to notify management about tunnel disconnection")
|
||||
} else {
|
||||
n.logger.WithField("domain", domain).Info("Successfully notified management about tunnel disconnection")
|
||||
}
|
||||
}
|
||||
|
||||
n.clientsMux.Lock()
|
||||
defer n.clientsMux.Unlock()
|
||||
delete(n.clients, domain)
|
||||
|
||||
@@ -38,13 +38,13 @@ import (
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
mgmtConn *grpc.ClientConn
|
||||
proxy *proxy.ReverseProxy
|
||||
netbird *roundtrip.NetBird
|
||||
acme *acme.Manager
|
||||
auth *auth.Middleware
|
||||
http *http.Server
|
||||
https *http.Server
|
||||
mgmtClient proto.ProxyServiceClient
|
||||
proxy *proxy.ReverseProxy
|
||||
netbird *roundtrip.NetBird
|
||||
acme *acme.Manager
|
||||
auth *auth.Middleware
|
||||
http *http.Server
|
||||
https *http.Server
|
||||
|
||||
// Mostly used for debugging on management.
|
||||
startTime time.Time
|
||||
@@ -64,6 +64,33 @@ type Server struct {
|
||||
OIDCScopes []string
|
||||
}
|
||||
|
||||
// NotifyStatus sends a status update to management about tunnel connectivity
|
||||
func (s *Server) NotifyStatus(ctx context.Context, accountID, reverseProxyID, domain string, connected bool) error {
|
||||
status := proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED
|
||||
if connected {
|
||||
status = proto.ProxyStatus_PROXY_STATUS_ACTIVE
|
||||
}
|
||||
|
||||
_, err := s.mgmtClient.SendStatusUpdate(ctx, &proto.SendStatusUpdateRequest{
|
||||
ReverseProxyId: reverseProxyID,
|
||||
AccountId: accountID,
|
||||
Status: status,
|
||||
CertificateIssued: false,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// NotifyCertificateIssued sends a notification to management that a certificate was issued
|
||||
func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID, reverseProxyID, domain string) error {
|
||||
_, err := s.mgmtClient.SendStatusUpdate(ctx, &proto.SendStatusUpdateRequest{
|
||||
ReverseProxyId: reverseProxyID,
|
||||
AccountId: accountID,
|
||||
Status: proto.ProxyStatus_PROXY_STATUS_ACTIVE,
|
||||
CertificateIssued: true,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
s.startTime = time.Now()
|
||||
|
||||
@@ -105,7 +132,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
"gRPC_address": mgmtURL.Host,
|
||||
"TLS_enabled": mgmtURL.Scheme == "https",
|
||||
}).Debug("starting management gRPC client")
|
||||
s.mgmtConn, err = grpc.NewClient(mgmtURL.Host,
|
||||
mgmtConn, err := grpc.NewClient(mgmtURL.Host,
|
||||
grpc.WithTransportCredentials(creds),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 20 * time.Second,
|
||||
@@ -116,18 +143,18 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create management connection: %w", err)
|
||||
}
|
||||
mgmtClient := proto.NewProxyServiceClient(s.mgmtConn)
|
||||
go s.newManagementMappingWorker(ctx, mgmtClient)
|
||||
s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
|
||||
go s.newManagementMappingWorker(ctx, s.mgmtClient)
|
||||
|
||||
// Initialize the netbird client, this is required to build peer connections
|
||||
// to proxy over.
|
||||
s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.Logger)
|
||||
s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.Logger, s)
|
||||
|
||||
// When generating ACME certificates, start a challenge server.
|
||||
tlsConfig := &tls.Config{}
|
||||
if s.GenerateACMECertificates {
|
||||
s.Logger.WithField("acme_server", s.ACMEDirectory).Debug("ACME certificates enabled, configuring certificate manager")
|
||||
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory)
|
||||
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s)
|
||||
s.http = &http.Server{
|
||||
Addr: s.ACMEChallengeAddress,
|
||||
Handler: s.acme.HTTPHandler(nil),
|
||||
@@ -175,7 +202,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
s.auth = auth.NewMiddleware()
|
||||
|
||||
// Configure Access logs to management server.
|
||||
accessLog := accesslog.NewLogger(mgmtClient, s.Logger)
|
||||
accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger)
|
||||
|
||||
// Finally, start the reverse proxy.
|
||||
s.https = &http.Server{
|
||||
@@ -279,11 +306,15 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
||||
}
|
||||
|
||||
func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
|
||||
if err := s.netbird.AddPeer(ctx, mapping.GetDomain(), mapping.GetSetupKey()); err != nil {
|
||||
return fmt.Errorf("create peer for domain %q: %w", mapping.GetDomain(), err)
|
||||
domain := mapping.GetDomain()
|
||||
accountID := mapping.GetAccountId()
|
||||
reverseProxyID := mapping.GetId()
|
||||
|
||||
if err := s.netbird.AddPeer(ctx, domain, mapping.GetSetupKey(), accountID, reverseProxyID); err != nil {
|
||||
return fmt.Errorf("create peer for domain %q: %w", domain, err)
|
||||
}
|
||||
if s.acme != nil {
|
||||
s.acme.AddDomain(mapping.GetDomain())
|
||||
s.acme.AddDomain(domain, accountID, reverseProxyID)
|
||||
}
|
||||
|
||||
// Pass the mapping through to the update function to avoid duplicating the
|
||||
@@ -299,13 +330,12 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
|
||||
// the auth and proxy mappings.
|
||||
// Note: this does require the management server to always send a
|
||||
// full mapping rather than deltas during a modification.
|
||||
mgmtClient := proto.NewProxyServiceClient(s.mgmtConn)
|
||||
var schemes []auth.Scheme
|
||||
if mapping.GetAuth().GetPassword() {
|
||||
schemes = append(schemes, auth.NewPassword(mgmtClient, mapping.GetId(), mapping.GetAccountId()))
|
||||
schemes = append(schemes, auth.NewPassword(s.mgmtClient, mapping.GetId(), mapping.GetAccountId()))
|
||||
}
|
||||
if mapping.GetAuth().GetPin() {
|
||||
schemes = append(schemes, auth.NewPin(mgmtClient, mapping.GetId(), mapping.GetAccountId()))
|
||||
schemes = append(schemes, auth.NewPin(s.mgmtClient, mapping.GetId(), mapping.GetAccountId()))
|
||||
}
|
||||
if mapping.GetAuth().GetOidc() != nil {
|
||||
oidc := mapping.GetAuth().GetOidc()
|
||||
@@ -323,14 +353,14 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
|
||||
}
|
||||
}
|
||||
if mapping.GetAuth().GetLink() {
|
||||
schemes = append(schemes, auth.NewLink(mgmtClient, mapping.GetId(), mapping.GetAccountId()))
|
||||
schemes = append(schemes, auth.NewLink(s.mgmtClient, mapping.GetId(), mapping.GetAccountId()))
|
||||
}
|
||||
s.auth.AddDomain(mapping.GetDomain(), schemes)
|
||||
s.proxy.AddMapping(s.protoToMapping(mapping))
|
||||
}
|
||||
|
||||
func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) {
|
||||
if err := s.netbird.RemovePeer(ctx, mapping.GetDomain()); err != nil {
|
||||
if err := s.netbird.RemovePeer(ctx, mapping.GetDomain(), mapping.GetAccountId(), mapping.GetId()); err != nil {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"domain": mapping.GetDomain(),
|
||||
"error": err,
|
||||
|
||||
Reference in New Issue
Block a user