diff --git a/infrastructure_files/getting-started.sh b/infrastructure_files/getting-started.sh index 7fd87ee8e..70088d66a 100755 --- a/infrastructure_files/getting-started.sh +++ b/infrastructure_files/getting-started.sh @@ -182,44 +182,6 @@ read_enable_proxy() { return 0 } -read_proxy_domain() { - local suggested_proxy="proxy.${BASE_DOMAIN}" - - echo "" > /dev/stderr - echo "NOTE: The proxy domain must be different from the management domain ($NETBIRD_DOMAIN)" > /dev/stderr - echo "to avoid TLS certificate conflicts." > /dev/stderr - echo "" > /dev/stderr - echo "You also need to add a wildcard DNS record for the proxy domain," > /dev/stderr - echo "e.g. *.${suggested_proxy} pointing to the same server domain as $NETBIRD_DOMAIN with a CNAME record." > /dev/stderr - echo "" > /dev/stderr - echo -n "Enter the domain for the NetBird Proxy (e.g. ${suggested_proxy}): " > /dev/stderr - read -r READ_PROXY_DOMAIN < /dev/tty - - if [[ -z "$READ_PROXY_DOMAIN" ]]; then - echo "The proxy domain cannot be empty." > /dev/stderr - read_proxy_domain - return - fi - - if [[ "$READ_PROXY_DOMAIN" == "$NETBIRD_DOMAIN" ]]; then - echo "" > /dev/stderr - echo "WARNING: The proxy domain cannot be the same as the management domain ($NETBIRD_DOMAIN)." > /dev/stderr - read_proxy_domain - return - fi - - echo ${READ_PROXY_DOMAIN} | grep ${NETBIRD_DOMAIN} > /dev/null - if [[ $? -eq 0 ]]; then - echo "" > /dev/stderr - echo "WARNING: The proxy domain cannot be a subdomain of the management domain ($NETBIRD_DOMAIN)." > /dev/stderr - read_proxy_domain - return - fi - - echo "$READ_PROXY_DOMAIN" - return 0 -} - read_traefik_acme_email() { echo "" > /dev/stderr echo "Enter your email for Let's Encrypt certificate notifications." > /dev/stderr @@ -334,7 +296,6 @@ initialize_default_values() { # NetBird Proxy configuration ENABLE_PROXY="false" - PROXY_DOMAIN="" PROXY_TOKEN="" return 0 } @@ -364,9 +325,6 @@ configure_reverse_proxy() { if [[ "$REVERSE_PROXY_TYPE" == "0" ]]; then TRAEFIK_ACME_EMAIL=$(read_traefik_acme_email) ENABLE_PROXY=$(read_enable_proxy) - if [[ "$ENABLE_PROXY" == "true" ]]; then - PROXY_DOMAIN=$(read_proxy_domain) - fi fi # Handle external Traefik-specific prompts (option 1) @@ -813,7 +771,7 @@ NB_PROXY_MANAGEMENT_ADDRESS=http://netbird-server:80 # Allow insecure gRPC connection to management (required for internal Docker network) NB_PROXY_ALLOW_INSECURE=true # Public URL where this proxy is reachable (used for cluster registration) -NB_PROXY_DOMAIN=$PROXY_DOMAIN +NB_PROXY_DOMAIN=$NETBIRD_DOMAIN NB_PROXY_ADDRESS=:8443 NB_PROXY_TOKEN=$PROXY_TOKEN NB_PROXY_CERTIFICATE_DIRECTORY=/certs @@ -1203,8 +1161,7 @@ print_builtin_traefik_instructions() { echo " The proxy handles its own TLS certificates via ACME TLS-ALPN-01 challenge." echo " Point your proxy domain to this server's domain address like in the examples below:" echo "" - echo " $PROXY_DOMAIN CNAME $NETBIRD_DOMAIN" - echo " *.$PROXY_DOMAIN CNAME $NETBIRD_DOMAIN" + echo " *.$NETBIRD_DOMAIN CNAME $NETBIRD_DOMAIN" echo "" fi return 0 diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index 52bd0bb2f..d623c8260 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -108,7 +108,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App holder: types.NewHolder(), expNewNetworkMap: newNetworkMapBuilder, expNewNetworkMapAIDs: expIDs, - } } diff --git a/management/server/account_request_buffer.go b/management/server/account_request_buffer.go index fa6c45856..e1672c2d0 100644 --- a/management/server/account_request_buffer.go +++ b/management/server/account_request_buffer.go @@ -86,7 +86,14 @@ func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, acco result := &AccountResult{Account: account, Err: err} for _, req := range requests { - req.ResultChan <- result + if account != nil { + // Shallow copy the account so each goroutine gets its own struct value. + // This prevents data races when callers mutate fields like Policies. + accountCopy := *account + req.ResultChan <- &AccountResult{Account: &accountCopy, Err: err} + } else { + req.ResultChan <- result + } close(req.ResultChan) } } diff --git a/management/server/types/account_components.go b/management/server/types/account_components.go index 1eb25cecc..bd4244546 100644 --- a/management/server/types/account_components.go +++ b/management/server/types/account_components.go @@ -368,7 +368,7 @@ func (a *Account) getPeersGroupsPoliciesRoutes( func (a *Account) getPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}, postureFailedPeers *map[string]map[string]struct{}) ([]string, bool) { peerInGroups := false - filteredPeerIDs := make([]string, 0, len(a.Peers)) + filteredPeerIDs := make([]string, 0, len(groups)) seenPeerIds := make(map[string]struct{}, len(groups)) for _, gid := range groups { @@ -378,7 +378,7 @@ func (a *Account) getPeersFromGroups(ctx context.Context, groups []string, peerI } if group.IsGroupAll() || len(groups) == 1 { - filteredPeerIDs = filteredPeerIDs[:0] + filteredPeerIDs = make([]string, 0, len(group.Peers)) peerInGroups = false for _, pid := range group.Peers { peer, ok := a.Peers[pid] diff --git a/management/server/types/networkmap_components.go b/management/server/types/networkmap_components.go index 12b1350b4..6f84c8d30 100644 --- a/management/server/types/networkmap_components.go +++ b/management/server/types/networkmap_components.go @@ -132,7 +132,7 @@ func (c *NetworkMapComponents) Calculate(ctx context.Context) *NetworkMap { sourcePeers, ) - dnsManagementStatus := c.getPeerDNSManagementStatus(targetPeerID) + dnsManagementStatus := c.getPeerDNSManagementStatusFromGroups(peerGroups) dnsUpdate := nbdns.Config{ ServiceEnable: dnsManagementStatus, } @@ -150,7 +150,7 @@ func (c *NetworkMapComponents) Calculate(ctx context.Context) *NetworkMap { customZones = append(customZones, c.AccountZones...) dnsUpdate.CustomZones = customZones - dnsUpdate.NameServerGroups = c.getPeerNSGroups(targetPeerID) + dnsUpdate.NameServerGroups = c.getPeerNSGroupsFromGroups(targetPeerID, peerGroups) } return &NetworkMap{ @@ -276,6 +276,16 @@ func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) ( peers := make([]*nbpeer.Peer, 0) return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { + protocol := rule.Protocol + if protocol == PolicyRuleProtocolNetbirdSSH { + protocol = PolicyRuleProtocolTCP + } + + protocolStr := string(protocol) + actionStr := string(rule.Action) + dirStr := strconv.Itoa(direction) + portsJoined := strings.Join(rule.Ports, ",") + for _, peer := range groupPeers { if peer == nil { continue @@ -286,21 +296,18 @@ func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) ( peersExists[peer.ID] = struct{}{} } - protocol := rule.Protocol - if protocol == PolicyRuleProtocolNetbirdSSH { - protocol = PolicyRuleProtocolTCP - } + peerIP := net.IP(peer.IP).String() fr := FirewallRule{ PolicyID: rule.ID, - PeerIP: net.IP(peer.IP).String(), + PeerIP: peerIP, Direction: direction, - Action: string(rule.Action), - Protocol: string(protocol), + Action: actionStr, + Protocol: protocolStr, } - ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) + - fr.Protocol + fr.Action + strings.Join(rule.Ports, ",") + ruleID := rule.ID + peerIP + dirStr + + protocolStr + actionStr + portsJoined if _, ok := rulesExists[ruleID]; ok { continue } @@ -311,13 +318,7 @@ func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) ( continue } - rules = append(rules, expandPortsAndRanges(fr, &PolicyRule{ - ID: rule.ID, - Ports: rule.Ports, - PortRanges: rule.PortRanges, - Protocol: rule.Protocol, - Action: rule.Action, - }, targetPeer)...) + rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...) } }, func() ([]*nbpeer.Peer, []*FirewallRule) { return peers, rules @@ -393,7 +394,7 @@ func (c *NetworkMapComponents) getPeerFromResource(resource Resource, peerID str } func (c *NetworkMapComponents) filterPeersByLoginExpiration(aclPeers []*nbpeer.Peer) ([]*nbpeer.Peer, []*nbpeer.Peer) { - var peersToConnect []*nbpeer.Peer + peersToConnect := make([]*nbpeer.Peer, 0, len(aclPeers)) var expiredPeers []*nbpeer.Peer for _, p := range aclPeers { @@ -408,35 +409,35 @@ func (c *NetworkMapComponents) filterPeersByLoginExpiration(aclPeers []*nbpeer.P return peersToConnect, expiredPeers } -func (c *NetworkMapComponents) getPeerDNSManagementStatus(peerID string) bool { - peerGroups := c.GetPeerGroups(peerID) - enabled := true +func (c *NetworkMapComponents) getPeerDNSManagementStatusFromGroups(peerGroups map[string]struct{}) bool { for _, groupID := range c.DNSSettings.DisabledManagementGroups { if _, found := peerGroups[groupID]; found { - enabled = false - break + return false } } - return enabled + return true } -func (c *NetworkMapComponents) getPeerNSGroups(peerID string) []*nbdns.NameServerGroup { - groupList := c.GetPeerGroups(peerID) - +func (c *NetworkMapComponents) getPeerNSGroupsFromGroups(peerID string, groupList map[string]struct{}) []*nbdns.NameServerGroup { var peerNSGroups []*nbdns.NameServerGroup + targetPeerInfo := c.GetPeerInfo(peerID) + if targetPeerInfo == nil { + return peerNSGroups + } + + peerIPStr := targetPeerInfo.IP.String() + for _, nsGroup := range c.NameServerGroups { if !nsGroup.Enabled { continue } for _, gID := range nsGroup.Groups { - _, found := groupList[gID] - if found { - targetPeerInfo := c.GetPeerInfo(peerID) - if targetPeerInfo != nil && !c.peerIsNameserver(targetPeerInfo, nsGroup) { + if _, found := groupList[gID]; found { + if !c.peerIsNameserver(peerIPStr, nsGroup) { peerNSGroups = append(peerNSGroups, nsGroup.Copy()) - break } + break } } } @@ -444,9 +445,9 @@ func (c *NetworkMapComponents) getPeerNSGroups(peerID string) []*nbdns.NameServe return peerNSGroups } -func (c *NetworkMapComponents) peerIsNameserver(peerInfo *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool { +func (c *NetworkMapComponents) peerIsNameserver(peerIPStr string, nsGroup *nbdns.NameServerGroup) bool { for _, ns := range nsGroup.NameServers { - if peerInfo.IP.String() == ns.IP.String() { + if peerIPStr == ns.IP.String() { return true } } @@ -487,14 +488,13 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute } seenRoute[r.ID] = struct{}{} - routeObj := c.copyRoute(r) - routeObj.Peer = peerInfo.Key + r.Peer = peerInfo.Key if r.Enabled { - enabledRoutes = append(enabledRoutes, routeObj) + enabledRoutes = append(enabledRoutes, r) return } - disabledRoutes = append(disabledRoutes, routeObj) + disabledRoutes = append(disabledRoutes, r) } for _, r := range c.Routes { @@ -508,7 +508,7 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute continue } - newPeerRoute := c.copyRoute(r) + newPeerRoute := r.Copy() newPeerRoute.Peer = id newPeerRoute.PeerGroups = nil newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) @@ -517,50 +517,13 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute } } if r.Peer == peerID { - takeRoute(c.copyRoute(r)) + takeRoute(r.Copy()) } } return enabledRoutes, disabledRoutes } -func (c *NetworkMapComponents) copyRoute(r *route.Route) *route.Route { - var groups, accessControlGroups, peerGroups []string - var domains domain.List - - if r.Groups != nil { - groups = append([]string{}, r.Groups...) - } - if r.AccessControlGroups != nil { - accessControlGroups = append([]string{}, r.AccessControlGroups...) - } - if r.PeerGroups != nil { - peerGroups = append([]string{}, r.PeerGroups...) - } - if r.Domains != nil { - domains = append(domain.List{}, r.Domains...) - } - - return &route.Route{ - ID: r.ID, - AccountID: r.AccountID, - Network: r.Network, - NetworkType: r.NetworkType, - Description: r.Description, - Peer: r.Peer, - PeerID: r.PeerID, - Metric: r.Metric, - Masquerade: r.Masquerade, - NetID: r.NetID, - Enabled: r.Enabled, - Groups: groups, - AccessControlGroups: accessControlGroups, - PeerGroups: peerGroups, - Domains: domains, - KeepRoute: r.KeepRoute, - SkipAutoApply: r.SkipAutoApply, - } -} func (c *NetworkMapComponents) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route { var filteredRoutes []*route.Route diff --git a/proxy/internal/acme/manager.go b/proxy/internal/acme/manager.go index ebc15314b..b1e532e83 100644 --- a/proxy/internal/acme/manager.go +++ b/proxy/internal/acme/manager.go @@ -7,9 +7,12 @@ import ( "encoding/asn1" "encoding/base64" "encoding/binary" + "encoding/pem" "fmt" + "math/rand/v2" "net" "slices" + "strings" "sync" "time" @@ -137,7 +140,12 @@ func (mgr *Manager) AddDomain(d domain.Domain, accountID, serviceID string) { // It acquires a distributed lock to prevent multiple replicas from issuing // duplicate ACME requests. The second replica will block until the first // finishes, then find the certificate in the cache. +// ACME and periodic disk reads race; whichever produces a valid certificate +// first wins. This handles cases where locking is unreliable and another +// replica already wrote the cert to the shared cache. func (mgr *Manager) prefetchCertificate(d domain.Domain) { + time.Sleep(time.Duration(rand.IntN(200)) * time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() @@ -153,26 +161,105 @@ func (mgr *Manager) prefetchCertificate(d domain.Domain) { defer unlock() } - hello := &tls.ClientHelloInfo{ - ServerName: name, - Conn: &dummyConn{ctx: ctx}, - } - - start := time.Now() - cert, err := mgr.GetCertificate(hello) - elapsed := time.Since(start) - if err != nil { - mgr.logger.Warnf("prefetch certificate for domain %q in %s: %v", name, elapsed.String(), err) - mgr.setDomainState(d, domainFailed, err.Error()) + if cert, err := mgr.readCertFromDisk(ctx, name); err == nil { + mgr.logger.Infof("certificate for domain %q already on disk, skipping ACME", name) + mgr.recordAndNotify(ctx, d, name, cert, 0) return } - if mgr.metrics != nil { + // Run ACME in a goroutine so we can race it against periodic disk reads. + // autocert uses its own internal context and cannot be cancelled externally. + type acmeResult struct { + cert *tls.Certificate + err error + } + acmeCh := make(chan acmeResult, 1) + hello := &tls.ClientHelloInfo{ServerName: name, Conn: &dummyConn{ctx: ctx}} + go func() { + cert, err := mgr.GetCertificate(hello) + acmeCh <- acmeResult{cert, err} + }() + + start := time.Now() + diskTicker := time.NewTicker(5 * time.Second) + defer diskTicker.Stop() + + for { + select { + case res := <-acmeCh: + elapsed := time.Since(start) + if res.err != nil { + mgr.logger.Warnf("prefetch certificate for domain %q in %s: %v", name, elapsed.String(), res.err) + mgr.setDomainState(d, domainFailed, res.err.Error()) + return + } + mgr.recordAndNotify(ctx, d, name, res.cert, elapsed) + return + + case <-diskTicker.C: + cert, err := mgr.readCertFromDisk(context.Background(), name) + if err != nil { + continue + } + mgr.logger.Infof("certificate for domain %q appeared on disk after %s", name, time.Since(start).Round(time.Millisecond)) + // Drain the ACME goroutine before marking ready — autocert holds + // an internal write lock on certState while ACME is in flight. + go func() { + select { + case <-acmeCh: + default: + } + mgr.recordAndNotify(context.Background(), d, name, cert, 0) + }() + return + + case <-ctx.Done(): + mgr.logger.Warnf("prefetch certificate for domain %q timed out", name) + mgr.setDomainState(d, domainFailed, ctx.Err().Error()) + return + } + } +} + +// readCertFromDisk reads and parses a certificate directly from the autocert +// DirCache, bypassing autocert's internal certState mutex. Safe to call +// concurrently with an in-flight ACME request for the same domain. +func (mgr *Manager) readCertFromDisk(ctx context.Context, name string) (*tls.Certificate, error) { + if mgr.Cache == nil { + return nil, fmt.Errorf("no cache configured") + } + data, err := mgr.Cache.Get(ctx, name) + if err != nil { + return nil, err + } + privBlock, certsPEM := pem.Decode(data) + if privBlock == nil || !strings.Contains(privBlock.Type, "PRIVATE") { + return nil, fmt.Errorf("no private key in cache for %q", name) + } + cert, err := tls.X509KeyPair(certsPEM, pem.EncodeToMemory(privBlock)) + if err != nil { + return nil, fmt.Errorf("parse cached certificate for %q: %w", name, err) + } + if len(cert.Certificate) > 0 { + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return nil, fmt.Errorf("parse leaf for %q: %w", name, err) + } + if time.Now().After(leaf.NotAfter) { + return nil, fmt.Errorf("cached certificate for %q expired at %s", name, leaf.NotAfter) + } + cert.Leaf = leaf + } + return &cert, nil +} + +// recordAndNotify records metrics, marks the domain ready, logs cert details, +// and notifies the cert notifier. +func (mgr *Manager) recordAndNotify(ctx context.Context, d domain.Domain, name string, cert *tls.Certificate, elapsed time.Duration) { + if elapsed > 0 && mgr.metrics != nil { mgr.metrics.RecordCertificateIssuance(elapsed) } - mgr.setDomainState(d, domainReady, "") - now := time.Now() if cert != nil && cert.Leaf != nil { leaf := cert.Leaf @@ -188,11 +275,9 @@ func (mgr *Manager) prefetchCertificate(d domain.Domain) { } else { mgr.logger.Infof("certificate for domain %q ready in %s", name, elapsed.Round(time.Millisecond)) } - mgr.mu.RLock() info := mgr.domains[d] mgr.mu.RUnlock() - if info != nil && mgr.certNotifier != nil { if err := mgr.certNotifier.NotifyCertificateIssued(ctx, info.accountID, info.serviceID, name); err != nil { mgr.logger.Warnf("notify certificate ready for domain %q: %v", name, err)