diff --git a/proxy/server.go b/proxy/server.go index 096e1be5c..908646726 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -157,41 +157,9 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { reg := prometheus.NewRegistry() s.meter = metrics.New(reg) - // The very first thing to do should be to connect to the Management server. - // Without this connection, the Proxy cannot do anything. - mgmtURL, err := url.Parse(s.ManagementAddress) + mgmtConn, err := s.dialManagement() if err != nil { - return fmt.Errorf("parse management address: %w", err) - } - creds := insecure.NewCredentials() - // Simple TLS check using management URL. - // Assume management TLS is enabled for gRPC as well if using HTTPS for the API. - if mgmtURL.Scheme == "https" { - certPool, err := x509.SystemCertPool() - if err != nil || certPool == nil { - // Fall back to embedded CAs if no OS-provided ones are available. - certPool = embeddedroots.Get() - } - - creds = credentials.NewTLS(&tls.Config{ - RootCAs: certPool, - }) - } - s.Logger.WithFields(log.Fields{ - "gRPC_address": mgmtURL.Host, - "TLS_enabled": mgmtURL.Scheme == "https", - }).Debug("starting management gRPC client") - mgmtConn, err := grpc.NewClient(mgmtURL.Host, - grpc.WithTransportCredentials(creds), - grpc.WithKeepaliveParams(keepalive.ClientParameters{ - Time: 20 * time.Second, - Timeout: 10 * time.Second, - PermitWithoutStream: true, - }), - proxygrpc.WithProxyToken(s.ProxyToken), - ) - if err != nil { - return fmt.Errorf("could not create management connection: %w", err) + return err } defer func() { if err := mgmtConn.Close(); err != nil { @@ -205,53 +173,9 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { // to proxy over. s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.ID, s.ProxyURL, s.WireguardPort, s.Logger, s, s.mgmtClient) - // When generating ACME certificates, start a challenge server. - tlsConfig := &tls.Config{} - if s.GenerateACMECertificates { - // Default to TLS-ALPN-01 challenge if not specified - if s.ACMEChallengeType == "" { - s.ACMEChallengeType = "tls-alpn-01" - } - s.Logger.WithFields(log.Fields{ - "acme_server": s.ACMEDirectory, - "challenge_type": s.ACMEChallengeType, - }).Debug("ACME certificates enabled, configuring certificate manager") - s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger, s.CertLockMethod) - - // Only start HTTP server for HTTP-01 challenge type - if s.ACMEChallengeType == "http-01" { - s.http = &http.Server{ - Addr: s.ACMEChallengeAddress, - Handler: s.acme.HTTPHandler(nil), - } - go func() { - if err := s.http.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - s.Logger.WithError(err).Error("ACME HTTP-01 challenge server failed") - } - }() - } - tlsConfig = s.acme.TLSConfig() - - // ServerName needs to be set to allow for ACME to work correctly - // when using CNAME URLs to access the proxy. - tlsConfig.ServerName = s.ProxyURL - - s.Logger.WithFields(log.Fields{ - "ServerName": s.ProxyURL, - "challenge_type": s.ACMEChallengeType, - }).Debug("ACME certificate manager configured") - } else { - s.Logger.Debug("ACME certificates disabled, using static certificates with file watching") - certPath := filepath.Join(s.CertificateDirectory, s.CertificateFile) - keyPath := filepath.Join(s.CertificateDirectory, s.CertificateKeyFile) - - certWatcher, err := certwatch.NewWatcher(certPath, keyPath, s.Logger) - if err != nil { - return fmt.Errorf("initialize certificate watcher: %w", err) - } - go certWatcher.Watch(ctx) - - tlsConfig.GetCertificate = certWatcher.GetCertificate + tlsConfig, err := s.configureTLS(ctx) + if err != nil { + return err } // Configure the reverse proxy using NetBird's HTTP Client Transport for proxying. @@ -340,6 +264,91 @@ const ( shutdownServiceTimeout = 5 * time.Second ) +func (s *Server) dialManagement() (*grpc.ClientConn, error) { + mgmtURL, err := url.Parse(s.ManagementAddress) + if err != nil { + return nil, fmt.Errorf("parse management address: %w", err) + } + creds := insecure.NewCredentials() + // Assume management TLS is enabled for gRPC as well if using HTTPS for the API. + if mgmtURL.Scheme == "https" { + certPool, err := x509.SystemCertPool() + if err != nil || certPool == nil { + // Fall back to embedded CAs if no OS-provided ones are available. + certPool = embeddedroots.Get() + } + creds = credentials.NewTLS(&tls.Config{ + RootCAs: certPool, + }) + } + s.Logger.WithFields(log.Fields{ + "gRPC_address": mgmtURL.Host, + "TLS_enabled": mgmtURL.Scheme == "https", + }).Debug("starting management gRPC client") + conn, err := grpc.NewClient(mgmtURL.Host, + grpc.WithTransportCredentials(creds), + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: 20 * time.Second, + Timeout: 10 * time.Second, + PermitWithoutStream: true, + }), + proxygrpc.WithProxyToken(s.ProxyToken), + ) + if err != nil { + return nil, fmt.Errorf("create management connection: %w", err) + } + return conn, nil +} + +func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) { + tlsConfig := &tls.Config{} + if !s.GenerateACMECertificates { + s.Logger.Debug("ACME certificates disabled, using static certificates with file watching") + certPath := filepath.Join(s.CertificateDirectory, s.CertificateFile) + keyPath := filepath.Join(s.CertificateDirectory, s.CertificateKeyFile) + + certWatcher, err := certwatch.NewWatcher(certPath, keyPath, s.Logger) + if err != nil { + return nil, fmt.Errorf("initialize certificate watcher: %w", err) + } + go certWatcher.Watch(ctx) + tlsConfig.GetCertificate = certWatcher.GetCertificate + return tlsConfig, nil + } + + if s.ACMEChallengeType == "" { + s.ACMEChallengeType = "tls-alpn-01" + } + s.Logger.WithFields(log.Fields{ + "acme_server": s.ACMEDirectory, + "challenge_type": s.ACMEChallengeType, + }).Debug("ACME certificates enabled, configuring certificate manager") + s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger, s.CertLockMethod) + + if s.ACMEChallengeType == "http-01" { + s.http = &http.Server{ + Addr: s.ACMEChallengeAddress, + Handler: s.acme.HTTPHandler(nil), + } + go func() { + if err := s.http.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + s.Logger.WithError(err).Error("ACME HTTP-01 challenge server failed") + } + }() + } + tlsConfig = s.acme.TLSConfig() + + // ServerName needs to be set to allow for ACME to work correctly + // when using CNAME URLs to access the proxy. + tlsConfig.ServerName = s.ProxyURL + + s.Logger.WithFields(log.Fields{ + "ServerName": s.ProxyURL, + "challenge_type": s.ACMEChallengeType, + }).Debug("ACME certificate manager configured") + return tlsConfig, nil +} + // gracefulShutdown performs a zero-downtime shutdown sequence. It marks the // readiness probe as failing, waits for load balancer propagation, drains // in-flight connections, and then stops all background services. @@ -492,36 +501,7 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr return fmt.Errorf("receive msg: %w", err) } s.Logger.Debug("Received mapping update, starting processing") - // Process msg updates sequentially to avoid conflict, so block - // additional receiving until this processing is completed. - for _, mapping := range msg.GetMapping() { - s.Logger.WithFields(log.Fields{ - "type": mapping.GetType(), - "domain": mapping.GetDomain(), - "path": mapping.GetPath(), - "id": mapping.GetId(), - }).Debug("Processing mapping update") - switch mapping.GetType() { - case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED: - if err := s.addMapping(ctx, mapping); err != nil { - // TODO: Retry this? Or maybe notify the management server that this mapping has failed? - s.Logger.WithFields(log.Fields{ - "service_id": mapping.GetId(), - "domain": mapping.GetDomain(), - "error": err, - }).Error("Error adding new mapping, ignoring this mapping and continuing processing") - } - case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED: - if err := s.updateMapping(ctx, mapping); err != nil { - s.Logger.WithFields(log.Fields{ - "service_id": mapping.GetId(), - "domain": mapping.GetDomain(), - }).Errorf("failed to update mapping: %v", err) - } - case proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED: - s.removeMapping(ctx, mapping) - } - } + s.processMappings(ctx, msg.GetMapping()) s.Logger.Debug("Processing mapping update completed") if !*initialSyncDone && msg.GetInitialSyncComplete() { @@ -535,6 +515,37 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr } } +func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) { + for _, mapping := range mappings { + s.Logger.WithFields(log.Fields{ + "type": mapping.GetType(), + "domain": mapping.GetDomain(), + "path": mapping.GetPath(), + "id": mapping.GetId(), + }).Debug("Processing mapping update") + switch mapping.GetType() { + case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED: + if err := s.addMapping(ctx, mapping); err != nil { + // TODO: Retry this? Or maybe notify the management server that this mapping has failed? + s.Logger.WithFields(log.Fields{ + "service_id": mapping.GetId(), + "domain": mapping.GetDomain(), + "error": err, + }).Error("Error adding new mapping, ignoring this mapping and continuing processing") + } + case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED: + if err := s.updateMapping(ctx, mapping); err != nil { + s.Logger.WithFields(log.Fields{ + "service_id": mapping.GetId(), + "domain": mapping.GetDomain(), + }).Errorf("failed to update mapping: %v", err) + } + case proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED: + s.removeMapping(ctx, mapping) + } + } +} + func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error { d := domain.Domain(mapping.GetDomain()) accountID := types.AccountID(mapping.GetAccountId())