mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-08 17:09:57 +00:00
Compare commits
6 Commits
fix-statem
...
fix/stale-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c4ef7ce237 | ||
|
|
3e2c29a355 | ||
|
|
d4a9b2d302 | ||
|
|
60d2fa08b0 | ||
|
|
1e7b16db0a | ||
|
|
b377d99933 |
4
.github/workflows/release.yml
vendored
4
.github/workflows/release.yml
vendored
@@ -29,10 +29,10 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Generate FreeBSD port diff
|
||||
run: bash release_files/freebsd-port-diff.sh
|
||||
run: bash -x release_files/freebsd-port-diff.sh
|
||||
|
||||
- name: Generate FreeBSD port issue body
|
||||
run: bash release_files/freebsd-port-issue-body.sh
|
||||
run: bash -x release_files/freebsd-port-issue-body.sh
|
||||
|
||||
- name: Check if diff was generated
|
||||
id: check_diff
|
||||
|
||||
4
.github/workflows/wasm-build-validation.yml
vendored
4
.github/workflows/wasm-build-validation.yml
vendored
@@ -65,7 +65,7 @@ jobs:
|
||||
|
||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||
|
||||
if [ ${SIZE} -gt 58720256 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
|
||||
if [ ${SIZE} -gt 62914560 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 60MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
@@ -3,7 +3,6 @@ package iptables
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net"
|
||||
"slices"
|
||||
|
||||
@@ -422,17 +421,12 @@ func (m *aclManager) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
// Clone the maps so the persisted state holds a private snapshot. The
|
||||
// live maps keep being mutated by subsequent rule operations while the
|
||||
// state manager marshals the state from its periodic-save goroutine.
|
||||
// Sharing them by reference races the two and aborts the process with a
|
||||
// concurrent map iteration and write.
|
||||
if m.v6 {
|
||||
currentState.ACLEntries6 = maps.Clone(m.entries)
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore.clone()
|
||||
currentState.ACLEntries6 = m.entries
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore
|
||||
} else {
|
||||
currentState.ACLEntries = maps.Clone(m.entries)
|
||||
currentState.ACLIPsetStore = m.ipsetStore.clone()
|
||||
currentState.ACLEntries = m.entries
|
||||
currentState.ACLIPsetStore = m.ipsetStore
|
||||
}
|
||||
|
||||
if err := m.stateManager.UpdateState(currentState); err != nil {
|
||||
|
||||
@@ -4,7 +4,6 @@ package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -750,17 +749,11 @@ func (r *router) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
// Clone the rule map so the persisted state holds a private snapshot. The
|
||||
// live map keeps being mutated by subsequent rule operations while the
|
||||
// state manager marshals the state from its periodic-save goroutine.
|
||||
// Sharing it by reference races the two and aborts the process with a
|
||||
// concurrent map iteration and write. The ipset counter guards itself
|
||||
// during marshaling, so it can be shared directly.
|
||||
if r.v6 {
|
||||
currentState.RouteRules6 = maps.Clone(r.rules)
|
||||
currentState.RouteRules6 = r.rules
|
||||
currentState.RouteIPsetCounter6 = r.ipsetCounter
|
||||
} else {
|
||||
currentState.RouteRules = maps.Clone(r.rules)
|
||||
currentState.RouteRules = r.rules
|
||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"maps"
|
||||
)
|
||||
import "encoding/json"
|
||||
|
||||
type ipList struct {
|
||||
ips map[string]struct{}
|
||||
@@ -22,14 +19,6 @@ func (s *ipList) addIP(ip string) {
|
||||
s.ips[ip] = struct{}{}
|
||||
}
|
||||
|
||||
// clone returns a deep copy of the ipList with its own ips map.
|
||||
func (s *ipList) clone() *ipList {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return &ipList{ips: maps.Clone(s.ips)}
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
func (s *ipList) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
@@ -66,19 +55,6 @@ func newIpsetStore() *ipsetStore {
|
||||
}
|
||||
}
|
||||
|
||||
// clone returns a deep copy of the ipsetStore with its own ipsets map and
|
||||
// independent ipList entries.
|
||||
func (s *ipsetStore) clone() *ipsetStore {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := &ipsetStore{ipsets: make(map[string]*ipList, len(s.ipsets))}
|
||||
for name, list := range s.ipsets {
|
||||
cloned.ipsets[name] = list.clone()
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
|
||||
r, ok := s.ipsets[ipsetName]
|
||||
return r, ok
|
||||
|
||||
@@ -806,6 +806,8 @@ func (g *BundleGenerator) addSyncResponse() error {
|
||||
AllowPartial: true,
|
||||
}
|
||||
|
||||
g.maskSecrets()
|
||||
|
||||
jsonBytes, err := options.Marshal(g.syncResponse)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate json: %w", err)
|
||||
@@ -818,6 +820,27 @@ func (g *BundleGenerator) addSyncResponse() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) maskSecrets() {
|
||||
if g.syncResponse == nil || g.syncResponse.NetbirdConfig == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if g.syncResponse.NetbirdConfig.Flow != nil {
|
||||
g.syncResponse.NetbirdConfig.Flow.TokenPayload = maskedValue
|
||||
|
||||
}
|
||||
|
||||
if g.syncResponse.NetbirdConfig.Relay != nil {
|
||||
g.syncResponse.NetbirdConfig.Relay.TokenPayload = maskedValue
|
||||
}
|
||||
|
||||
for i := range g.syncResponse.NetbirdConfig.Turns {
|
||||
if g.syncResponse.NetbirdConfig.Turns[i] != nil {
|
||||
g.syncResponse.NetbirdConfig.Turns[i].Password = maskedValue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addStateFile() error {
|
||||
sm := profilemanager.NewServiceManager("")
|
||||
path := sm.GetStatePath()
|
||||
|
||||
@@ -666,8 +666,10 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
|
||||
case resp := <-conn.sendChan:
|
||||
if err := conn.sendResponse(resp); err != nil {
|
||||
errChan <- err
|
||||
log.WithContext(conn.ctx).Tracef("Failed to send response to proxy %s: %v", conn.proxyID, err)
|
||||
return
|
||||
}
|
||||
log.WithContext(conn.ctx).Tracef("Send response to proxy %s", conn.proxyID)
|
||||
case <-conn.ctx.Done():
|
||||
return
|
||||
}
|
||||
@@ -978,6 +980,7 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
Mode: m.Mode,
|
||||
ListenPort: m.ListenPort,
|
||||
AccessRestrictions: m.AccessRestrictions,
|
||||
Private: m.Private,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
88
management/internals/shared/grpc/proxy_clone_test.go
Normal file
88
management/internals/shared/grpc/proxy_clone_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// authTokenField is the only per-proxy field that shallowCloneMapping must NOT
|
||||
// copy from the source, since callers assign it individually after cloning.
|
||||
const authTokenField = "AuthToken"
|
||||
|
||||
// TestShallowCloneMapping_ClonesAllFields populates every exported field of
|
||||
// ProxyMapping with a non-zero value and verifies the clone carries each one
|
||||
// (except AuthToken). It uses reflection so adding a new field to ProxyMapping
|
||||
// without updating shallowCloneMapping fails this test.
|
||||
func TestShallowCloneMapping_ClonesAllFields(t *testing.T) {
|
||||
src := &proto.ProxyMapping{}
|
||||
populated := populateExportedFields(t, reflect.ValueOf(src).Elem())
|
||||
require.NotEmpty(t, populated, "ProxyMapping should expose fields to populate")
|
||||
|
||||
clone := shallowCloneMapping(src)
|
||||
require.NotNil(t, clone, "clone must not be nil")
|
||||
|
||||
srcVal := reflect.ValueOf(src).Elem()
|
||||
cloneVal := reflect.ValueOf(clone).Elem()
|
||||
|
||||
for _, name := range populated {
|
||||
srcField := srcVal.FieldByName(name).Interface()
|
||||
cloneField := cloneVal.FieldByName(name).Interface()
|
||||
|
||||
if name == authTokenField {
|
||||
assert.Zero(t, cloneField, "AuthToken must not be cloned; it is set per proxy after cloning")
|
||||
continue
|
||||
}
|
||||
|
||||
assert.Equal(t, srcField, cloneField, "field %s must be carried over by shallowCloneMapping", name)
|
||||
}
|
||||
}
|
||||
|
||||
// populateExportedFields sets a non-zero value on every settable exported field
|
||||
// of the struct and returns their names.
|
||||
func populateExportedFields(t *testing.T, v reflect.Value) []string {
|
||||
t.Helper()
|
||||
|
||||
var names []string
|
||||
typ := v.Type()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
structField := typ.Field(i)
|
||||
|
||||
if structField.PkgPath != "" || !field.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
setNonZero(t, field, structField.Name)
|
||||
names = append(names, structField.Name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// setNonZero assigns a deterministic non-zero value based on the field kind.
|
||||
func setNonZero(t *testing.T, field reflect.Value, name string) {
|
||||
t.Helper()
|
||||
|
||||
switch field.Kind() {
|
||||
case reflect.String:
|
||||
field.SetString("non-zero-" + name)
|
||||
case reflect.Bool:
|
||||
field.SetBool(true)
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
field.SetInt(7)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
field.SetUint(7)
|
||||
case reflect.Ptr:
|
||||
field.Set(reflect.New(field.Type().Elem()))
|
||||
case reflect.Slice:
|
||||
field.Set(reflect.MakeSlice(field.Type(), 1, 1))
|
||||
case reflect.Map:
|
||||
field.Set(reflect.MakeMapWithSize(field.Type(), 0))
|
||||
default:
|
||||
t.Fatalf("unhandled field kind %s for field %s; extend setNonZero", field.Kind(), name)
|
||||
}
|
||||
}
|
||||
@@ -1216,6 +1216,7 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
|
||||
Preload("NetworkResources").
|
||||
Preload("Onboarding").
|
||||
Preload("Services.Targets").
|
||||
Preload("Domains").
|
||||
Take(&account, idQueryCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
|
||||
@@ -1302,7 +1303,7 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, 12)
|
||||
errChan := make(chan error, 16)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
@@ -1403,6 +1404,17 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
||||
account.Services = services
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
domains, err := s.ListCustomDomains(ctx, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.Domains = domains
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -21,6 +23,63 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// TestGetAccount_LoadsCustomDomains verifies GetAccount populates account.Domains.
|
||||
// SynthesizePrivateServiceZones depends on this relation to anchor a custom-domain
|
||||
// private service's DNS zone; without the preload the relation is empty and the
|
||||
// service is silently skipped, so a custom domain never resolves on clients.
|
||||
func TestGetAccount_LoadsCustomDomains(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
|
||||
assertGetAccountLoadsCustomDomains(t, store)
|
||||
}
|
||||
|
||||
func TestPostgresql_GetAccount_LoadsCustomDomains(t *testing.T) {
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
assertGetAccountLoadsCustomDomains(t, store)
|
||||
}
|
||||
|
||||
// assertGetAccountLoadsCustomDomains exercises both the gorm and pgx GetAccount
|
||||
// paths: it persists two custom domains and asserts the relation comes back
|
||||
// populated, which SynthesizePrivateServiceZones relies on.
|
||||
func assertGetAccountLoadsCustomDomains(t *testing.T, store Store) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
accountID := "acct-custom-domains"
|
||||
require.NoError(t, store.SaveAccount(ctx, newAccountWithId(ctx, accountID, "user-1", "")))
|
||||
|
||||
_, err := store.CreateCustomDomain(ctx, accountID, "example.com", "eu.proxy.netbird.io", true)
|
||||
require.NoError(t, err, "creating the first custom domain must succeed")
|
||||
_, err = store.CreateCustomDomain(ctx, accountID, "apps.acme.io", "us.proxy.netbird.io", false)
|
||||
require.NoError(t, err, "creating the second custom domain must succeed")
|
||||
|
||||
account, err := store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, account.Domains, 2, "GetAccount must preload the account's custom domains")
|
||||
|
||||
byDomain := map[string]string{}
|
||||
for _, d := range account.Domains {
|
||||
require.NotNil(t, d)
|
||||
byDomain[d.Domain] = d.TargetCluster
|
||||
}
|
||||
assert.Equal(t, "eu.proxy.netbird.io", byDomain["example.com"], "custom domain must carry its target cluster")
|
||||
assert.Equal(t, "us.proxy.netbird.io", byDomain["apps.acme.io"], "custom domain must carry its target cluster")
|
||||
}
|
||||
|
||||
// TestGetAccount_ComprehensiveFieldValidation validates that GetAccount properly loads
|
||||
// all fields and nested objects from the database, including deeply nested structures.
|
||||
func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) {
|
||||
|
||||
@@ -273,7 +273,7 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
|
||||
}
|
||||
|
||||
peerGroups := a.GetPeerGroups(peerID)
|
||||
zonesByCluster := map[string]*nbdns.CustomZone{}
|
||||
zonesByApex := map[string]*nbdns.CustomZone{}
|
||||
|
||||
for _, svc := range a.Services {
|
||||
if svc == nil || !svc.Enabled || !svc.Private {
|
||||
@@ -290,19 +290,24 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
|
||||
continue
|
||||
}
|
||||
|
||||
zone, exists := zonesByCluster[svc.ProxyCluster]
|
||||
serviceDomainZone := a.privateServiceDomainZone(svc)
|
||||
if serviceDomainZone == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
zone, exists := zonesByApex[serviceDomainZone]
|
||||
if !exists {
|
||||
// NonAuthoritative makes this a match-only zone: queries for
|
||||
// names without an explicit record fall through to the
|
||||
// upstream resolver instead of returning NXDOMAIN. Without
|
||||
// it, adding a single private service would black-hole every
|
||||
// other name under the cluster apex.
|
||||
// other name under the zone apex.
|
||||
zone = &nbdns.CustomZone{
|
||||
Domain: dns.Fqdn(svc.ProxyCluster),
|
||||
Domain: dns.Fqdn(serviceDomainZone),
|
||||
Records: []nbdns.SimpleRecord{},
|
||||
NonAuthoritative: true,
|
||||
}
|
||||
zonesByCluster[svc.ProxyCluster] = zone
|
||||
zonesByApex[serviceDomainZone] = zone
|
||||
}
|
||||
|
||||
emitted := 0
|
||||
@@ -340,8 +345,8 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
|
||||
}
|
||||
}
|
||||
|
||||
out := make([]nbdns.CustomZone, 0, len(zonesByCluster))
|
||||
for _, zone := range zonesByCluster {
|
||||
out := make([]nbdns.CustomZone, 0, len(zonesByApex))
|
||||
for _, zone := range zonesByApex {
|
||||
if len(zone.Records) == 0 {
|
||||
continue
|
||||
}
|
||||
@@ -357,6 +362,33 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
|
||||
return out
|
||||
}
|
||||
|
||||
// privateServiceDomainZone returns the DNS zone name for the given private service domain by
|
||||
// looking at the proxy cluster domain then the custom domains.
|
||||
func (a *Account) privateServiceDomainZone(svc *service.Service) string {
|
||||
if domainFromSuffix(svc.Domain, svc.ProxyCluster) {
|
||||
return svc.ProxyCluster
|
||||
}
|
||||
|
||||
// Longest matching custom domain wins
|
||||
zoneName := ""
|
||||
for _, d := range a.Domains {
|
||||
if d == nil || d.TargetCluster != svc.ProxyCluster {
|
||||
continue
|
||||
}
|
||||
if domainFromSuffix(svc.Domain, d.Domain) && len(d.Domain) > len(zoneName) {
|
||||
zoneName = d.Domain
|
||||
}
|
||||
}
|
||||
return zoneName
|
||||
}
|
||||
|
||||
func domainFromSuffix(domain, suffix string) bool {
|
||||
if suffix == "" {
|
||||
return false
|
||||
}
|
||||
return domain == suffix || strings.HasSuffix(domain, "."+suffix)
|
||||
}
|
||||
|
||||
// peerInDistributionGroups reports whether any of the peer's groups
|
||||
// matches the service's bearer-auth distribution_groups.
|
||||
func peerInDistributionGroups(peerGroups LookupMap, distributionGroups []string) bool {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
proxydomain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
@@ -234,6 +235,113 @@ func TestPrivateZone_GetPeerNetworkMap_PeerOutsideGroups_OmitsSynthZone(t *testi
|
||||
assert.False(t, ok, "peer outside the distribution_groups must not see the synth zone")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_CustomDomain_ZoneApexIsRegisteredDomain(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
// A custom-domain service: Domain is the custom FQDN, ProxyCluster
|
||||
// is the cluster serving it, and account.Domains holds the registered
|
||||
// custom domain. The synth zone apex must be the registered domain,
|
||||
// not the cluster, or the client's match-only zone never intercepts
|
||||
// the query.
|
||||
account.Services[0].Domain = "app.example.com"
|
||||
account.Domains = []*proxydomain.Domain{
|
||||
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
|
||||
}
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 1, "custom-domain service must still produce one zone")
|
||||
zone := zones[0]
|
||||
assert.Equal(t, "example.com.", zone.Domain, "zone apex must be the registered custom domain, not the cluster or the service FQDN")
|
||||
assert.True(t, zone.NonAuthoritative, "synth zone must remain match-only")
|
||||
require.Len(t, zone.Records, 1, "custom-domain service yields one A record")
|
||||
rec := zone.Records[0]
|
||||
assert.Equal(t, "app.example.com.", rec.Name, "record name is the custom service FQDN")
|
||||
assert.Equal(t, "100.64.0.99", rec.RData, "record points at the embedded proxy peer's tunnel IP")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_CustomAndFreeDomain_SeparateZones(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Domains = []*proxydomain.Domain{
|
||||
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
|
||||
}
|
||||
account.Services = append(account.Services, &service.Service{
|
||||
ID: "svc-2",
|
||||
AccountID: "acct-1",
|
||||
Name: "custom",
|
||||
Domain: "app.example.com",
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
})
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 2, "a free-domain and a custom-domain service must not collapse into one zone")
|
||||
|
||||
free, ok := findCustomZone(zones, "eu.proxy.netbird.io")
|
||||
require.True(t, ok, "free-domain service keeps the shared cluster-apex zone")
|
||||
require.Len(t, free.Records, 1, "cluster zone carries only the free-domain record")
|
||||
assert.Equal(t, "myapp.eu.proxy.netbird.io.", free.Records[0].Name, "cluster zone record is the free-domain FQDN")
|
||||
|
||||
custom, ok := findCustomZone(zones, "example.com")
|
||||
require.True(t, ok, "custom-domain service gets its own zone at the registered custom domain apex")
|
||||
require.Len(t, custom.Records, 1, "custom zone carries only the custom-domain record")
|
||||
assert.Equal(t, "app.example.com.", custom.Records[0].Name, "custom zone record is the custom-domain FQDN")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_TwoServicesSameCustomDomain_OneZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Domains = []*proxydomain.Domain{
|
||||
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
|
||||
}
|
||||
account.Services[0].Domain = "a.example.com"
|
||||
account.Services = append(account.Services, &service.Service{
|
||||
ID: "svc-2",
|
||||
AccountID: "acct-1",
|
||||
Name: "bapp",
|
||||
Domain: "b.example.com",
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
})
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 1, "two services under the same registered custom domain must share one zone")
|
||||
assert.Equal(t, "example.com.", zones[0].Domain, "shared zone apex is the registered custom domain")
|
||||
require.Len(t, zones[0].Records, 2, "both services surface as records in the shared custom-domain zone")
|
||||
names := []string{zones[0].Records[0].Name, zones[0].Records[1].Name}
|
||||
assert.ElementsMatch(t, []string{"a.example.com.", "b.example.com."}, names, "both custom-domain service FQDNs must surface")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_CustomDomainNotRegistered_NoZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
// Service domain is outside the cluster and no account.Domains entry
|
||||
// covers it: there is no apex that would intercept the query, so the
|
||||
// service must be skipped rather than emit an unmatchable record.
|
||||
account.Services[0].Domain = "app.example.com"
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "a custom-domain service with no registered domain apex must not produce a zone")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_CustomDomainClusterMismatch_NoZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
// The registered custom domain matches the service FQDN by suffix but
|
||||
// targets a different cluster than the service's ProxyCluster. It must
|
||||
// be ignored, leaving no apex to intercept the query — otherwise the
|
||||
// zone would point at this cluster's proxy peers under a domain owned
|
||||
// by a different cluster.
|
||||
account.Services[0].Domain = "app.example.com"
|
||||
account.Domains = []*proxydomain.Domain{
|
||||
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "us.proxy.netbird.io", Validated: true},
|
||||
}
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "a custom domain targeting a different cluster must not anchor the service zone")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_TwoServicesSameCluster_OneZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Services = append(account.Services, &service.Service{
|
||||
@@ -254,3 +362,72 @@ func TestSynthesizePrivateServiceZones_TwoServicesSameCluster_OneZone(t *testing
|
||||
names := []string{zones[0].Records[0].Name, zones[0].Records[1].Name}
|
||||
assert.ElementsMatch(t, []string{"myapp.eu.proxy.netbird.io.", "anotherapp.eu.proxy.netbird.io."}, names, "both service domains must surface")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_MixedClusterCustomAndPublic(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Domains = []*proxydomain.Domain{
|
||||
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
|
||||
}
|
||||
|
||||
privateService := func(id, domain string) *service.Service {
|
||||
return &service.Service{
|
||||
ID: id,
|
||||
AccountID: "acct-1",
|
||||
Name: id,
|
||||
Domain: domain,
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
}
|
||||
}
|
||||
publicService := func(id, domain string) *service.Service {
|
||||
s := privateService(id, domain)
|
||||
s.Private = false
|
||||
return s
|
||||
}
|
||||
|
||||
account.Services = []*service.Service{
|
||||
// 3 private services under the cluster suffix.
|
||||
privateService("cluster-1", "cluster1.eu.proxy.netbird.io"),
|
||||
privateService("cluster-2", "cluster2.eu.proxy.netbird.io"),
|
||||
privateService("cluster-3", "cluster3.eu.proxy.netbird.io"),
|
||||
// 4 private services under the custom domain suffix.
|
||||
privateService("custom-1", "custom1.example.com"),
|
||||
privateService("custom-2", "custom2.example.com"),
|
||||
privateService("custom-3", "custom3.example.com"),
|
||||
privateService("custom-4", "custom4.example.com"),
|
||||
// 2 public services, one per suffix, must not surface.
|
||||
publicService("public-cluster", "public.eu.proxy.netbird.io"),
|
||||
publicService("public-custom", "public.example.com"),
|
||||
}
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 2, "one zone per apex: the cluster apex and the custom domain apex")
|
||||
|
||||
cluster, ok := findCustomZone(zones, "eu.proxy.netbird.io")
|
||||
require.True(t, ok, "cluster-suffix services collapse into the cluster-apex zone")
|
||||
clusterNames := recordNames(cluster)
|
||||
assert.ElementsMatch(t,
|
||||
[]string{"cluster1.eu.proxy.netbird.io.", "cluster2.eu.proxy.netbird.io.", "cluster3.eu.proxy.netbird.io."},
|
||||
clusterNames,
|
||||
"only the 3 private cluster services surface in the cluster zone (public one excluded)")
|
||||
|
||||
custom, ok := findCustomZone(zones, "example.com")
|
||||
require.True(t, ok, "custom-suffix services collapse into the custom-domain-apex zone")
|
||||
customNames := recordNames(custom)
|
||||
assert.ElementsMatch(t,
|
||||
[]string{"custom1.example.com.", "custom2.example.com.", "custom3.example.com.", "custom4.example.com."},
|
||||
customNames,
|
||||
"only the 4 private custom services surface in the custom zone (public one excluded)")
|
||||
}
|
||||
|
||||
// recordNames returns the record names of a zone for order-independent assertions.
|
||||
func recordNames(zone nbdns.CustomZone) []string {
|
||||
names := make([]string, 0, len(zone.Records))
|
||||
for _, r := range zone.Records {
|
||||
names = append(names, r.Name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
@@ -249,6 +249,7 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
Private: private,
|
||||
MaxDialTimeout: maxDialTimeout,
|
||||
MaxSessionIdleTimeout: maxSessionIdleTimeout,
|
||||
MappingBatchWatchdog: envDurationOrDefault("NB_PROXY_MAPPING_BATCH_WATCHDOG", 0),
|
||||
GeoDataDir: geoDataDir,
|
||||
CrowdSecAPIURL: crowdsecAPIURL,
|
||||
CrowdSecAPIKey: crowdsecAPIKey,
|
||||
|
||||
@@ -28,6 +28,10 @@ import (
|
||||
|
||||
const deviceNamePrefix = "ingress-proxy-"
|
||||
|
||||
const clientStopTimeout = 30 * time.Second
|
||||
|
||||
const createProxyPeerTimeout = 30 * time.Second
|
||||
|
||||
// backendKey identifies a backend by its host:port from the target URL.
|
||||
type backendKey string
|
||||
|
||||
@@ -162,6 +166,7 @@ type NetBird struct {
|
||||
|
||||
clientsMux sync.RWMutex
|
||||
clients map[types.AccountID]*clientEntry
|
||||
lifecycleMu sync.Map
|
||||
initLogOnce sync.Once
|
||||
statusNotifier statusNotifier
|
||||
// readyHandler runs after the embedded client for an account reports
|
||||
@@ -177,6 +182,10 @@ type NetBird struct {
|
||||
// (i.e. when a new client was actually created, not when an existing one
|
||||
// was reused). The duration covers keygen + gRPC CreateProxyPeer + embed.New.
|
||||
OnAddPeer func(d time.Duration, err error)
|
||||
|
||||
// startClient runs the post-create client startup. Nil uses runClientStartup;
|
||||
// tests override it to avoid a real embed client.Start.
|
||||
startClient func(accountID types.AccountID, client *embed.Client)
|
||||
}
|
||||
|
||||
// ClientDebugInfo contains debug information about a client.
|
||||
@@ -200,31 +209,20 @@ type skipTLSVerifyContextKey struct{}
|
||||
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error {
|
||||
si := serviceInfo{serviceID: serviceID}
|
||||
|
||||
n.clientsMux.Lock()
|
||||
if n.registerExistingClient(accountID, key, si) {
|
||||
return nil
|
||||
}
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
entry.services[key] = si
|
||||
started := entry.started
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).Debug("registered service with existing client")
|
||||
|
||||
if started && n.statusNotifier != nil {
|
||||
// Use a background context, not the caller's: the management
|
||||
// connection notification must land even if the request /
|
||||
// stream that triggered this registration is cancelled.
|
||||
// Mirrors the async runClientStartup path.
|
||||
if err := n.statusNotifier.NotifyStatus(context.Background(), accountID, serviceID, true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).WithError(err).Warn("failed to notify status for existing client")
|
||||
}
|
||||
lifecycle := n.accountLifecycle(accountID)
|
||||
lifecycle.Lock()
|
||||
transferred := false
|
||||
defer func() {
|
||||
if !transferred {
|
||||
lifecycle.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
if n.registerExistingClient(accountID, key, si) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -234,10 +232,10 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
|
||||
n.OnAddPeer(time.Since(createStart), err)
|
||||
}
|
||||
if err != nil {
|
||||
n.clientsMux.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
n.clientsMux.Lock()
|
||||
n.clients[accountID] = entry
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
@@ -246,17 +244,64 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
|
||||
"service_key": key,
|
||||
}).Info("created new client for account")
|
||||
|
||||
// Attempt to start the client in the background; if this fails we will
|
||||
// retry on the first request via RoundTrip. runClientStartup uses its
|
||||
// own background context so the caller's request-scoped ctx can't
|
||||
// cancel the inbound bring-up.
|
||||
go n.runClientStartup(accountID, entry.client)
|
||||
transferred = true
|
||||
go func() {
|
||||
defer lifecycle.Unlock()
|
||||
n.startClientStartup(accountID, entry.client)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *NetBird) startClientStartup(accountID types.AccountID, client *embed.Client) {
|
||||
if n.startClient != nil {
|
||||
n.startClient(accountID, client)
|
||||
return
|
||||
}
|
||||
n.runClientStartup(accountID, client)
|
||||
}
|
||||
|
||||
// registerExistingClient registers the service against an already-present
|
||||
// client for the account and returns true when it did. It notifies management
|
||||
// of the new service when the client is already started.
|
||||
func (n *NetBird) registerExistingClient(accountID types.AccountID, key ServiceKey, si serviceInfo) bool {
|
||||
n.clientsMux.Lock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
n.clientsMux.Unlock()
|
||||
return false
|
||||
}
|
||||
entry.services[key] = si
|
||||
started := entry.started
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).Debug("registered service with existing client")
|
||||
|
||||
if started && n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(context.Background(), accountID, si.serviceID, true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).WithError(err).Warn("failed to notify status for existing client")
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// accountLifecycle returns the per-account lifecycle mutex, serialising client
|
||||
// creation against teardown so a slow client.Stop cannot race a new
|
||||
// client.Start for the same account, without blocking clientsMux.
|
||||
func (n *NetBird) accountLifecycle(accountID types.AccountID) *sync.Mutex {
|
||||
mu, _ := n.lifecycleMu.LoadOrStore(accountID, &sync.Mutex{})
|
||||
return mu.(*sync.Mutex)
|
||||
}
|
||||
|
||||
// createClientEntry generates a WireGuard keypair, authenticates with management,
|
||||
// and creates an embedded NetBird client. Must be called with clientsMux held.
|
||||
// and creates an embedded NetBird client. Must be called with the account's
|
||||
// lifecycle mutex held.
|
||||
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) {
|
||||
serviceID := si.serviceID
|
||||
n.logger.WithFields(log.Fields{
|
||||
@@ -276,7 +321,9 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
"public_key": publicKey.String(),
|
||||
}).Debug("authenticating new proxy peer with management")
|
||||
|
||||
resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
|
||||
createCtx, cancel := context.WithTimeout(ctx, createProxyPeerTimeout)
|
||||
defer cancel()
|
||||
resp, err := n.mgmtClient.CreateProxyPeer(createCtx, &proto.CreateProxyPeerRequest{
|
||||
ServiceId: string(serviceID),
|
||||
AccountId: string(accountID),
|
||||
Token: authToken,
|
||||
@@ -444,6 +491,15 @@ func (n *NetBird) notifyClientReady(accountID types.AccountID, client *embed.Cli
|
||||
// RemovePeer unregisters a service from an account. The client is only stopped
|
||||
// when no services are using it anymore.
|
||||
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key ServiceKey) error {
|
||||
lifecycle := n.accountLifecycle(accountID)
|
||||
lifecycle.Lock()
|
||||
transferred := false
|
||||
defer func() {
|
||||
if !transferred {
|
||||
lifecycle.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
n.clientsMux.Lock()
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
@@ -466,17 +522,8 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
|
||||
delete(entry.services, key)
|
||||
|
||||
stopClient := len(entry.services) == 0
|
||||
var client *embed.Client
|
||||
var transport, insecureTransport *http.Transport
|
||||
var inbound any
|
||||
var stopHandler func(types.AccountID, any)
|
||||
if stopClient {
|
||||
n.logger.WithField("account_id", accountID).Info("stopping client, no more services")
|
||||
client = entry.client
|
||||
transport = entry.transport
|
||||
insecureTransport = entry.insecureTransport
|
||||
inbound = entry.inbound
|
||||
stopHandler = n.stopHandler
|
||||
delete(n.clients, accountID)
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
@@ -490,19 +537,40 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
|
||||
n.notifyDisconnect(ctx, accountID, key, si.serviceID)
|
||||
|
||||
if stopClient {
|
||||
if inbound != nil && stopHandler != nil {
|
||||
stopHandler(accountID, inbound)
|
||||
}
|
||||
transport.CloseIdleConnections()
|
||||
insecureTransport.CloseIdleConnections()
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
|
||||
}
|
||||
transferred = true
|
||||
go n.stopClientLocked(accountID, lifecycle, entry)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// stopClientLocked releases a client's resources off the caller's goroutine so a
|
||||
// slow client.Stop cannot wedge the mapping receive loop (which calls RemovePeer
|
||||
// synchronously). It unlocks lifecycle when done so a new client.Start for the
|
||||
// same account waits for this teardown.
|
||||
func (n *NetBird) stopClientLocked(accountID types.AccountID, lifecycle *sync.Mutex, entry *clientEntry) {
|
||||
defer lifecycle.Unlock()
|
||||
|
||||
if entry.inbound != nil && n.stopHandler != nil {
|
||||
n.stopHandler(accountID, entry.inbound)
|
||||
}
|
||||
if entry.transport != nil {
|
||||
entry.transport.CloseIdleConnections()
|
||||
}
|
||||
if entry.insecureTransport != nil {
|
||||
entry.insecureTransport.CloseIdleConnections()
|
||||
}
|
||||
if entry.client == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), clientStopTimeout)
|
||||
defer cancel()
|
||||
if err := entry.client.Stop(ctx); err != nil {
|
||||
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NetBird) notifyDisconnect(ctx context.Context, accountID types.AccountID, key ServiceKey, serviceID types.ServiceID) {
|
||||
if n.statusNotifier == nil {
|
||||
return
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -22,6 +23,18 @@ func (m *mockMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxy
|
||||
return &proto.CreateProxyPeerResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
// signalMgmtClient closes entered the first time CreateProxyPeer is called, so
|
||||
// tests can detect AddPeer reaching client creation.
|
||||
type signalMgmtClient struct {
|
||||
entered chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (m *signalMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
|
||||
m.once.Do(func() { close(m.entered) })
|
||||
return &proto.CreateProxyPeerResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
type mockStatusNotifier struct {
|
||||
mu sync.Mutex
|
||||
statuses []statusCall
|
||||
@@ -52,11 +65,15 @@ func (m *mockStatusNotifier) calls() []statusCall {
|
||||
// mockNetBird creates a NetBird instance for testing without actually connecting.
|
||||
// It uses an invalid management URL to prevent real connections.
|
||||
func mockNetBird() *NetBird {
|
||||
return NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
MgmtAddr: "http://invalid.test:9999",
|
||||
WGPort: 0,
|
||||
PreSharedKey: "",
|
||||
}, nil, nil, &mockMgmtClient{})
|
||||
// Skip the real embed client.Start, which would hang against the unreachable
|
||||
// mgmt URL and (now that the lifecycle lock spans startup) serialise removes.
|
||||
nb.startClient = func(types.AccountID, *embed.Client) {}
|
||||
return nb
|
||||
}
|
||||
|
||||
func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
|
||||
@@ -288,6 +305,7 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
|
||||
WGPort: 0,
|
||||
PreSharedKey: "",
|
||||
}, nil, notifier, &mockMgmtClient{})
|
||||
nb.startClient = func(types.AccountID, *embed.Client) {}
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add first service — creates a new client entry.
|
||||
@@ -372,6 +390,117 @@ func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
|
||||
assert.False(t, calls[0].connected)
|
||||
}
|
||||
|
||||
// TestNetBird_RemovePeer_TeardownIsAsync proves the fix for the receive-loop
|
||||
// stall: RemovePeer must return promptly even when the client teardown blocks,
|
||||
// because teardown runs off the caller's goroutine. The receive loop calls
|
||||
// RemovePeer synchronously, so a blocking teardown inline would wedge it.
|
||||
func TestNetBird_RemovePeer_TeardownIsAsync(t *testing.T) {
|
||||
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
MgmtAddr: "http://invalid.test:9999",
|
||||
}, nil, &mockStatusNotifier{}, &mockMgmtClient{})
|
||||
|
||||
accountID := types.AccountID("acct-async-teardown")
|
||||
key := DomainServiceKey("svc.example")
|
||||
|
||||
teardownEntered := make(chan struct{})
|
||||
releaseTeardown := make(chan struct{})
|
||||
nb.SetClientLifecycle(nil, func(types.AccountID, any) {
|
||||
close(teardownEntered)
|
||||
<-releaseTeardown
|
||||
})
|
||||
|
||||
nb.clientsMux.Lock()
|
||||
nb.clients[accountID] = &clientEntry{
|
||||
services: map[ServiceKey]serviceInfo{key: {serviceID: types.ServiceID("svc-1")}},
|
||||
started: true,
|
||||
inbound: struct{}{},
|
||||
}
|
||||
nb.clientsMux.Unlock()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() { done <- nb.RemovePeer(context.Background(), accountID, key) }()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("RemovePeer did not return while teardown was blocked — teardown is not async")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-teardownEntered:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("teardown never ran")
|
||||
}
|
||||
|
||||
close(releaseTeardown)
|
||||
}
|
||||
|
||||
// TestNetBird_AddPeer_WaitsForTeardown proves the lifecycle lock serialises a
|
||||
// new client bringup behind an in-flight teardown for the same account, so a
|
||||
// slow client.Stop can never race a new client.Start for that account.
|
||||
//
|
||||
// It targets the handoff race specifically: AddPeer is launched immediately
|
||||
// after RemovePeer returns, WITHOUT waiting for the teardown goroutine to start.
|
||||
// This only passes if RemovePeer acquires the lifecycle lock synchronously
|
||||
// (before returning) and hands it to the teardown goroutine — if the goroutine
|
||||
// acquired the lock itself, AddPeer could win the lock in this window and start
|
||||
// a replacement client while the old teardown is still pending.
|
||||
func TestNetBird_AddPeer_WaitsForTeardown(t *testing.T) {
|
||||
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
MgmtAddr: "http://invalid.test:9999",
|
||||
}, nil, &mockStatusNotifier{}, &mockMgmtClient{})
|
||||
nb.startClient = func(types.AccountID, *embed.Client) {}
|
||||
|
||||
accountID := types.AccountID("acct-serialize")
|
||||
key := DomainServiceKey("svc.example")
|
||||
|
||||
addEntered := make(chan struct{})
|
||||
releaseTeardown := make(chan struct{})
|
||||
nb.SetClientLifecycle(nil, func(types.AccountID, any) {
|
||||
// Block teardown until released. If AddPeer ever reaches createClientEntry
|
||||
// (signalled via the mgmt client below) while we hold the lock, the lock
|
||||
// failed to serialise and the test fails before we release.
|
||||
<-releaseTeardown
|
||||
})
|
||||
|
||||
nb.clientsMux.Lock()
|
||||
nb.clients[accountID] = &clientEntry{
|
||||
services: map[ServiceKey]serviceInfo{key: {serviceID: types.ServiceID("svc-1")}},
|
||||
started: true,
|
||||
inbound: struct{}{},
|
||||
}
|
||||
nb.clientsMux.Unlock()
|
||||
|
||||
// createClientEntry calls CreateProxyPeer; closing addEntered there tells us
|
||||
// AddPeer got past the lifecycle lock and into client creation.
|
||||
nb.mgmtClient = &signalMgmtClient{entered: addEntered}
|
||||
|
||||
require.NoError(t, nb.RemovePeer(context.Background(), accountID, key))
|
||||
|
||||
// Launch AddPeer with NO synchronisation against the teardown goroutine.
|
||||
addReturned := make(chan struct{})
|
||||
go func() {
|
||||
_ = nb.AddPeer(context.Background(), accountID, DomainServiceKey("svc2.example"), "key-2", types.ServiceID("svc-2"))
|
||||
close(addReturned)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-addEntered:
|
||||
t.Fatal("AddPeer entered client creation while teardown held the lifecycle lock — handoff race not closed")
|
||||
case <-addReturned:
|
||||
t.Fatal("AddPeer completed while teardown held the lifecycle lock — not serialised")
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
}
|
||||
|
||||
close(releaseTeardown)
|
||||
select {
|
||||
case <-addReturned:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("AddPeer never completed after teardown released the lifecycle lock")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotifyClientReady_UsesBackgroundCtx pins the contract that the
|
||||
// post-Start hooks (readyHandler + statusNotifier.NotifyStatus) run on
|
||||
// a fresh context.Background() rather than inheriting the AddPeer
|
||||
|
||||
@@ -114,6 +114,10 @@ type Config struct {
|
||||
MaxDialTimeout time.Duration
|
||||
// MaxSessionIdleTimeout caps the per-service session idle timeout.
|
||||
MaxSessionIdleTimeout time.Duration
|
||||
// MappingBatchWatchdog bounds how long a single mapping batch may spend
|
||||
// being applied before the receive loop reconnects to resync. Zero falls
|
||||
// back to the internal default.
|
||||
MappingBatchWatchdog time.Duration
|
||||
|
||||
// GeoDataDir is the directory containing GeoLite2 MMDB files.
|
||||
GeoDataDir string
|
||||
@@ -164,6 +168,7 @@ func New(ctx context.Context, cfg Config) *Server {
|
||||
Private: cfg.Private,
|
||||
MaxDialTimeout: cfg.MaxDialTimeout,
|
||||
MaxSessionIdleTimeout: cfg.MaxSessionIdleTimeout,
|
||||
MappingBatchWatchdog: cfg.MappingBatchWatchdog,
|
||||
GeoDataDir: cfg.GeoDataDir,
|
||||
CrowdSecAPIURL: cfg.CrowdSecAPIURL,
|
||||
CrowdSecAPIKey: cfg.CrowdSecAPIKey,
|
||||
|
||||
282
proxy/mapping_stall_test.go
Normal file
282
proxy/mapping_stall_test.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// blockingMgmtClient implements roundtrip's managementClient interface.
|
||||
// CreateProxyPeer parks until release is closed, signalling entry on entered.
|
||||
// This reproduces the confirmed real-world stall: createClientEntry calls
|
||||
// CreateProxyPeer synchronously while holding clientsMux, and the proxy's
|
||||
// receive loop calls that path synchronously inside processMappings.
|
||||
type blockingMgmtClient struct {
|
||||
entered chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (b *blockingMgmtClient) CreateProxyPeer(ctx context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
|
||||
b.once.Do(func() { close(b.entered) })
|
||||
// Park until the caller's context is cancelled. In production this ctx is
|
||||
// the gRPC mapping-stream context with no per-call timeout, so a slow or
|
||||
// unresponsive CreateProxyPeer parks the receive loop here indefinitely.
|
||||
<-ctx.Done()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
// gatedMappingStream is a mock GetMappingUpdate client stream that hands out a
|
||||
// pre-seeded list of messages, then records how many times Recv advanced. It
|
||||
// lets the test observe whether the single-threaded receive loop ever gets
|
||||
// past the first (blocking) batch to pull the second message.
|
||||
type gatedMappingStream struct {
|
||||
grpc.ClientStream
|
||||
messages []*proto.GetMappingUpdateResponse
|
||||
idx int32
|
||||
}
|
||||
|
||||
func (g *gatedMappingStream) Recv() (*proto.GetMappingUpdateResponse, error) {
|
||||
i := int(atomic.LoadInt32(&g.idx))
|
||||
if i >= len(g.messages) {
|
||||
// Block instead of returning EOF so the loop doesn't exit; we only
|
||||
// care whether the loop ever reaches this second Recv at all.
|
||||
select {}
|
||||
}
|
||||
msg := g.messages[i]
|
||||
atomic.AddInt32(&g.idx, 1)
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (g *gatedMappingStream) deliveredCount() int32 { return atomic.LoadInt32(&g.idx) }
|
||||
|
||||
func (g *gatedMappingStream) Header() (metadata.MD, error) { return nil, nil } //nolint:nilnil
|
||||
func (g *gatedMappingStream) Trailer() metadata.MD { return nil }
|
||||
func (g *gatedMappingStream) CloseSend() error { return nil }
|
||||
func (g *gatedMappingStream) Context() context.Context { return context.Background() }
|
||||
func (g *gatedMappingStream) SendMsg(any) error { return nil }
|
||||
func (g *gatedMappingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
// noopNotifier satisfies roundtrip's statusNotifier interface.
|
||||
type noopNotifier struct{}
|
||||
|
||||
func (noopNotifier) NotifyStatus(context.Context, types.AccountID, types.ServiceID, bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// noopProxyClient is a proto.ProxyServiceClient that no-ops the one method the
|
||||
// teardown unwind reaches (SendStatusUpdate, via notifyError when the parked
|
||||
// AddPeer is cancelled). The embedded nil interface satisfies the rest at
|
||||
// compile time; none of those methods are called by this test.
|
||||
type noopProxyClient struct {
|
||||
proto.ProxyServiceClient
|
||||
}
|
||||
|
||||
func (noopProxyClient) SendStatusUpdate(context.Context, *proto.SendStatusUpdateRequest, ...grpc.CallOption) (*proto.SendStatusUpdateResponse, error) {
|
||||
return &proto.SendStatusUpdateResponse{}, nil
|
||||
}
|
||||
|
||||
// TestMappingStream_StallsWhenApplyBlocks proves the deadlock: the proxy's
|
||||
// mapping receive loop processes batches strictly serially, so when applying
|
||||
// one batch blocks (here: createClientEntry parked on a synchronous
|
||||
// CreateProxyPeer call, exactly as observed in production), the loop never
|
||||
// advances to Recv the next batch. Management can keep sending updates onto
|
||||
// the stream with no error and no channel overflow, yet the proxy applies
|
||||
// nothing further — it is stuck.
|
||||
func TestMappingStream_StallsWhenApplyBlocks(t *testing.T) {
|
||||
logger := log.New()
|
||||
logger.SetLevel(log.PanicLevel)
|
||||
|
||||
mgmt := &blockingMgmtClient{
|
||||
entered: make(chan struct{}),
|
||||
}
|
||||
|
||||
nb := roundtrip.NewNetBird(
|
||||
context.Background(),
|
||||
"proxy-test",
|
||||
"proxy.example.com",
|
||||
roundtrip.ClientConfig{},
|
||||
logger,
|
||||
noopNotifier{},
|
||||
mgmt,
|
||||
)
|
||||
|
||||
s := &Server{
|
||||
Logger: logger,
|
||||
netbird: nb,
|
||||
mgmtClient: noopProxyClient{},
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
// First batch: a CREATED mapping for a brand-new account. addMapping ->
|
||||
// netbird.AddPeer -> createClientEntry -> CreateProxyPeer, which blocks.
|
||||
// Empty Path keeps setupHTTPMapping a no-op (it returns early), so the
|
||||
// ONLY blocking point is the synchronous CreateProxyPeer in AddPeer —
|
||||
// no routers/auth need wiring. The second batch exists only to detect
|
||||
// whether the loop ever advances past the blocked first batch.
|
||||
stream := &gatedMappingStream{
|
||||
messages: []*proto.GetMappingUpdateResponse{
|
||||
{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "svc-1",
|
||||
AccountId: "acct-1",
|
||||
AuthToken: "token-1",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "svc-2",
|
||||
AccountId: "acct-2",
|
||||
AuthToken: "token-2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// Unblock the parked apply on teardown via ctx (CreateProxyPeer returns
|
||||
// ctx.Err()), so the wedged loop goroutine unwinds before embed.New —
|
||||
// avoiding any dependency on collaborators this test deliberately leaves
|
||||
// nil. The deadlock is fully proven before this fires.
|
||||
t.Cleanup(cancel)
|
||||
|
||||
loopDone := make(chan struct{})
|
||||
syncDone := false
|
||||
go func() {
|
||||
defer close(loopDone)
|
||||
_ = s.handleMappingStream(ctx, stream, &syncDone, time.Time{})
|
||||
}()
|
||||
|
||||
// The loop must reach the blocking apply for the first batch.
|
||||
select {
|
||||
case <-mgmt.entered:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("receive loop never reached CreateProxyPeer for the first batch")
|
||||
}
|
||||
|
||||
// THE DEADLOCK: while the first batch is parked in CreateProxyPeer, the
|
||||
// single-threaded loop cannot advance. The second batch is never pulled,
|
||||
// even though it is already available on the stream. Give it ample time.
|
||||
// deliveredCount is atomic; syncDone is intentionally not read here because
|
||||
// the loop goroutine owns it (reading it from the test would race).
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
assert.Equal(t, int32(1), stream.deliveredCount(),
|
||||
"loop must NOT consume the second batch while the first is blocked in apply — proxy is stuck")
|
||||
|
||||
select {
|
||||
case <-loopDone:
|
||||
t.Fatal("receive loop returned while it should be wedged in apply")
|
||||
default:
|
||||
// Still wedged, as expected.
|
||||
}
|
||||
}
|
||||
|
||||
// TestMappingStream_StallsWhenRemoveBlocks proves the deadlock for the REMOVE
|
||||
// path observed in production: a mapping remove tears down the account's last
|
||||
// embedded client via netbird.RemovePeer -> client.Stop -> Engine.Stop, whose
|
||||
// jobExecutorWG.Wait() is unbounded. Because the receive loop is single-
|
||||
// threaded, a blocked remove wedges the loop: no further mapping updates of any
|
||||
// kind (create/modify/remove) are applied, while management keeps sending them
|
||||
// successfully (no send error, no channel-full). Matches the reported symptom:
|
||||
// the last log line is a remove that stops a client, then silence.
|
||||
func TestMappingStream_StallsWhenRemoveBlocks(t *testing.T) {
|
||||
logger := log.New()
|
||||
logger.SetLevel(log.PanicLevel)
|
||||
|
||||
enteredRemove := make(chan struct{})
|
||||
blockRemove := make(chan struct{})
|
||||
var once sync.Once
|
||||
|
||||
s := &Server{
|
||||
Logger: logger,
|
||||
mgmtClient: noopProxyClient{},
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
// Stand in for netbird.RemovePeer -> client.Stop hanging on
|
||||
// Engine.Stop's unbounded jobExecutorWG.Wait(). Only the first remove
|
||||
// blocks; later removes return immediately so the recovery assertion
|
||||
// can observe the loop advancing.
|
||||
removePeer: func(ctx context.Context, _ types.AccountID, _ roundtrip.ServiceKey) error {
|
||||
first := false
|
||||
once.Do(func() {
|
||||
first = true
|
||||
close(enteredRemove)
|
||||
})
|
||||
if !first {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-blockRemove:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Batch 1 removes a service (blocks in teardown). Batch 2 is a later update
|
||||
// that must never be applied while the remove is wedged.
|
||||
stream := &gatedMappingStream{
|
||||
messages: []*proto.GetMappingUpdateResponse{
|
||||
{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, Id: "svc-1", AccountId: "acct-1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, Id: "svc-2", AccountId: "acct-1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
loopDone := make(chan struct{})
|
||||
syncDone := false
|
||||
go func() {
|
||||
defer close(loopDone)
|
||||
_ = s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{})
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-enteredRemove:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("receive loop never reached the blocking remove for the first batch")
|
||||
}
|
||||
|
||||
// THE DEADLOCK: the loop is parked in the blocked remove and cannot advance.
|
||||
// syncDone is owned by the loop goroutine, so it is not read here.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
assert.Equal(t, int32(1), stream.deliveredCount(),
|
||||
"loop must NOT consume the second batch while the first remove is blocked — proxy is stuck")
|
||||
|
||||
select {
|
||||
case <-loopDone:
|
||||
t.Fatal("receive loop returned while it should be wedged on the remove")
|
||||
default:
|
||||
}
|
||||
|
||||
// Unblock and confirm the wedge was solely the blocked remove: the loop
|
||||
// then advances and consumes the next batch.
|
||||
close(blockRemove)
|
||||
assert.Eventually(t, func() bool {
|
||||
return stream.deliveredCount() >= 2
|
||||
}, 2*time.Second, 5*time.Millisecond,
|
||||
"once the remove unblocks, the loop must advance and consume the next batch")
|
||||
}
|
||||
@@ -118,6 +118,9 @@ type Server struct {
|
||||
// The mapping worker waits on this before processing updates.
|
||||
routerReady chan struct{}
|
||||
|
||||
// removePeer defaults to netbird.RemovePeer; overridable in tests.
|
||||
removePeer func(ctx context.Context, accountID types.AccountID, key roundtrip.ServiceKey) error
|
||||
|
||||
// inbound, when non-nil, manages per-account inbound listeners. Set by
|
||||
// initPrivateInbound only when Private is true so the standalone
|
||||
// proxy keeps its zero-overhead default path.
|
||||
@@ -227,6 +230,10 @@ type Server struct {
|
||||
// Zero means no cap (the proxy honors whatever management sends).
|
||||
// Set via NB_PROXY_MAX_SESSION_IDLE_TIMEOUT for shared deployments.
|
||||
MaxSessionIdleTimeout time.Duration
|
||||
// MappingBatchWatchdog bounds how long a single mapping batch may spend
|
||||
// in processMappings before the receive loop reconnects to resync.
|
||||
// Zero uses defaultMappingBatchWatchdog.
|
||||
MappingBatchWatchdog time.Duration
|
||||
}
|
||||
|
||||
// clampIdleTimeout returns d capped to MaxSessionIdleTimeout when configured.
|
||||
@@ -1172,24 +1179,30 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
|
||||
s.healthChecker.SetManagementConnected(false)
|
||||
}
|
||||
|
||||
connected := false
|
||||
onConnected := func() { connected = true }
|
||||
|
||||
var streamErr error
|
||||
if syncSupported {
|
||||
streamErr = s.trySyncMappings(ctx, client, &initialSyncDone)
|
||||
streamErr = s.trySyncMappings(ctx, client, &initialSyncDone, onConnected)
|
||||
if isSyncUnimplemented(streamErr) {
|
||||
syncSupported = false
|
||||
s.Logger.Info("management does not support SyncMappings, falling back to GetMappingUpdate")
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone, onConnected)
|
||||
}
|
||||
} else {
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone, onConnected)
|
||||
}
|
||||
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(false)
|
||||
}
|
||||
|
||||
// Stream established — reset backoff so the next failure retries quickly.
|
||||
bo.Reset()
|
||||
// Reset backoff only when a stream actually connected, so immediate
|
||||
// connect failures still back off instead of spinning.
|
||||
if connected {
|
||||
bo.Reset()
|
||||
}
|
||||
|
||||
if streamErr == nil {
|
||||
return fmt.Errorf("stream closed by server")
|
||||
@@ -1221,7 +1234,7 @@ func (s *Server) proxyCapabilities() *proto.ProxyCapabilities {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
|
||||
func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool, onConnected func()) error {
|
||||
connectTime := time.Now()
|
||||
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: s.ID,
|
||||
@@ -1234,6 +1247,7 @@ func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServ
|
||||
return fmt.Errorf("create mapping stream: %w", err)
|
||||
}
|
||||
|
||||
onConnected()
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(true)
|
||||
}
|
||||
@@ -1242,7 +1256,7 @@ func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServ
|
||||
return s.handleMappingStream(ctx, mappingClient, initialSyncDone, connectTime)
|
||||
}
|
||||
|
||||
func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
|
||||
func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool, onConnected func()) error {
|
||||
connectTime := time.Now()
|
||||
stream, err := client.SyncMappings(ctx)
|
||||
if err != nil {
|
||||
@@ -1263,6 +1277,7 @@ func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceC
|
||||
return fmt.Errorf("send sync init: %w", err)
|
||||
}
|
||||
|
||||
onConnected()
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(true)
|
||||
}
|
||||
@@ -1307,7 +1322,9 @@ func (s *Server) handleSyncMappingsStream(ctx context.Context, stream proto.Prox
|
||||
|
||||
batchStart := time.Now()
|
||||
s.Logger.Debug("Received mapping update, starting processing")
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
if err := s.processMappingsGuarded(ctx, msg.GetMapping()); err != nil {
|
||||
return err
|
||||
}
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
|
||||
|
||||
@@ -1391,7 +1408,9 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
||||
|
||||
batchStart := time.Now()
|
||||
s.Logger.Debug("Received mapping update, starting processing")
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
if err := s.processMappingsGuarded(ctx, msg.GetMapping()); err != nil {
|
||||
return err
|
||||
}
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
|
||||
}
|
||||
@@ -1456,6 +1475,44 @@ func redactMappingForLog(m *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
return c
|
||||
}
|
||||
|
||||
const defaultMappingBatchWatchdog = 2 * time.Minute
|
||||
|
||||
// mappingBatchWatchdog returns the configured batch watchdog or the default.
|
||||
func (s *Server) mappingBatchWatchdog() time.Duration {
|
||||
if s.MappingBatchWatchdog > 0 {
|
||||
return s.MappingBatchWatchdog
|
||||
}
|
||||
return defaultMappingBatchWatchdog
|
||||
}
|
||||
|
||||
// processMappingsGuarded applies a batch under a watchdog, returning an error
|
||||
// if processing exceeds the watchdog so the caller reconnects and resyncs
|
||||
// instead of wedging silently.
|
||||
func (s *Server) processMappingsGuarded(ctx context.Context, mappings []*proto.ProxyMapping) error {
|
||||
batchCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
s.processMappings(batchCtx, mappings)
|
||||
}()
|
||||
|
||||
watchdog := s.mappingBatchWatchdog()
|
||||
timer := time.NewTimer(watchdog)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
s.Logger.Errorf("processing mapping batch exceeded %s, cancelling and reconnecting to resync", watchdog)
|
||||
return fmt.Errorf("mapping batch processing stalled after %s", watchdog)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
|
||||
debug := s.Logger != nil && s.Logger.IsLevelEnabled(log.DebugLevel)
|
||||
for _, mapping := range mappings {
|
||||
@@ -1951,7 +2008,11 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
|
||||
func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) {
|
||||
accountID := types.AccountID(mapping.GetAccountId())
|
||||
svcKey := s.serviceKeyForMapping(mapping)
|
||||
if err := s.netbird.RemovePeer(ctx, accountID, svcKey); err != nil {
|
||||
removePeer := s.removePeer
|
||||
if removePeer == nil {
|
||||
removePeer = s.netbird.RemovePeer
|
||||
}
|
||||
if err := removePeer(ctx, accountID, svcKey); err != nil {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_id": mapping.GetId(),
|
||||
|
||||
Reference in New Issue
Block a user