mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-09 17:39:57 +00:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a40028092d | ||
|
|
13200265d8 | ||
|
|
ed7a9363aa | ||
|
|
d56859dc5d | ||
|
|
367d37050b | ||
|
|
106527182f | ||
|
|
8e1d5b78c2 | ||
|
|
d3b63c6be9 | ||
|
|
60d2fa08b0 | ||
|
|
1e7b16db0a | ||
|
|
b377d99933 |
4
.github/workflows/release.yml
vendored
4
.github/workflows/release.yml
vendored
@@ -29,10 +29,10 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Generate FreeBSD port diff
|
||||
run: bash release_files/freebsd-port-diff.sh
|
||||
run: bash -x release_files/freebsd-port-diff.sh
|
||||
|
||||
- name: Generate FreeBSD port issue body
|
||||
run: bash release_files/freebsd-port-issue-body.sh
|
||||
run: bash -x release_files/freebsd-port-issue-body.sh
|
||||
|
||||
- name: Check if diff was generated
|
||||
id: check_diff
|
||||
|
||||
4
.github/workflows/wasm-build-validation.yml
vendored
4
.github/workflows/wasm-build-validation.yml
vendored
@@ -65,7 +65,7 @@ jobs:
|
||||
|
||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||
|
||||
if [ ${SIZE} -gt 58720256 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
|
||||
if [ ${SIZE} -gt 62914560 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 60MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
@@ -3,6 +3,7 @@ package iptables
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net"
|
||||
"slices"
|
||||
|
||||
@@ -421,12 +422,17 @@ func (m *aclManager) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
// Clone the maps so the persisted state holds a private snapshot. The
|
||||
// live maps keep being mutated by subsequent rule operations while the
|
||||
// state manager marshals the state from its periodic-save goroutine.
|
||||
// Sharing them by reference races the two and aborts the process with a
|
||||
// concurrent map iteration and write.
|
||||
if m.v6 {
|
||||
currentState.ACLEntries6 = m.entries
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore
|
||||
currentState.ACLEntries6 = maps.Clone(m.entries)
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore.clone()
|
||||
} else {
|
||||
currentState.ACLEntries = m.entries
|
||||
currentState.ACLIPsetStore = m.ipsetStore
|
||||
currentState.ACLEntries = maps.Clone(m.entries)
|
||||
currentState.ACLIPsetStore = m.ipsetStore.clone()
|
||||
}
|
||||
|
||||
if err := m.stateManager.UpdateState(currentState); err != nil {
|
||||
|
||||
@@ -4,6 +4,7 @@ package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -749,11 +750,17 @@ func (r *router) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
// Clone the rule map so the persisted state holds a private snapshot. The
|
||||
// live map keeps being mutated by subsequent rule operations while the
|
||||
// state manager marshals the state from its periodic-save goroutine.
|
||||
// Sharing it by reference races the two and aborts the process with a
|
||||
// concurrent map iteration and write. The ipset counter guards itself
|
||||
// during marshaling, so it can be shared directly.
|
||||
if r.v6 {
|
||||
currentState.RouteRules6 = r.rules
|
||||
currentState.RouteRules6 = maps.Clone(r.rules)
|
||||
currentState.RouteIPsetCounter6 = r.ipsetCounter
|
||||
} else {
|
||||
currentState.RouteRules = r.rules
|
||||
currentState.RouteRules = maps.Clone(r.rules)
|
||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package iptables
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
"maps"
|
||||
)
|
||||
|
||||
type ipList struct {
|
||||
ips map[string]struct{}
|
||||
@@ -19,6 +22,14 @@ func (s *ipList) addIP(ip string) {
|
||||
s.ips[ip] = struct{}{}
|
||||
}
|
||||
|
||||
// clone returns a deep copy of the ipList with its own ips map.
|
||||
func (s *ipList) clone() *ipList {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return &ipList{ips: maps.Clone(s.ips)}
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
func (s *ipList) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
@@ -55,6 +66,19 @@ func newIpsetStore() *ipsetStore {
|
||||
}
|
||||
}
|
||||
|
||||
// clone returns a deep copy of the ipsetStore with its own ipsets map and
|
||||
// independent ipList entries.
|
||||
func (s *ipsetStore) clone() *ipsetStore {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := &ipsetStore{ipsets: make(map[string]*ipList, len(s.ipsets))}
|
||||
for name, list := range s.ipsets {
|
||||
cloned.ipsets[name] = list.clone()
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
|
||||
r, ok := s.ipsets[ipsetName]
|
||||
return r, ok
|
||||
|
||||
@@ -806,6 +806,8 @@ func (g *BundleGenerator) addSyncResponse() error {
|
||||
AllowPartial: true,
|
||||
}
|
||||
|
||||
g.maskSecrets()
|
||||
|
||||
jsonBytes, err := options.Marshal(g.syncResponse)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate json: %w", err)
|
||||
@@ -818,6 +820,27 @@ func (g *BundleGenerator) addSyncResponse() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) maskSecrets() {
|
||||
if g.syncResponse == nil || g.syncResponse.NetbirdConfig == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if g.syncResponse.NetbirdConfig.Flow != nil {
|
||||
g.syncResponse.NetbirdConfig.Flow.TokenPayload = maskedValue
|
||||
|
||||
}
|
||||
|
||||
if g.syncResponse.NetbirdConfig.Relay != nil {
|
||||
g.syncResponse.NetbirdConfig.Relay.TokenPayload = maskedValue
|
||||
}
|
||||
|
||||
for i := range g.syncResponse.NetbirdConfig.Turns {
|
||||
if g.syncResponse.NetbirdConfig.Turns[i] != nil {
|
||||
g.syncResponse.NetbirdConfig.Turns[i].Password = maskedValue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addStateFile() error {
|
||||
sm := profilemanager.NewServiceManager("")
|
||||
path := sm.GetStatePath()
|
||||
|
||||
@@ -777,13 +777,24 @@ func (s *DefaultServer) applyHostConfig() {
|
||||
// context is released rather than leaked until GC.
|
||||
func (s *DefaultServer) registerFallback() {
|
||||
originalNameservers := s.hostManager.getOriginalNameservers()
|
||||
if len(originalNameservers) == 0 {
|
||||
|
||||
serverIP := s.service.RuntimeIP()
|
||||
var servers []netip.AddrPort
|
||||
for _, ns := range originalNameservers {
|
||||
if ns == serverIP {
|
||||
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, serverIP)
|
||||
continue
|
||||
}
|
||||
servers = append(servers, netip.AddrPortFrom(ns, DefaultPort))
|
||||
}
|
||||
|
||||
if len(servers) == 0 {
|
||||
log.Debugf("no fallback upstreams to register; clearing PriorityFallback handler")
|
||||
s.clearFallback()
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("registering original nameservers %v as upstream handlers with priority %d", originalNameservers, PriorityFallback)
|
||||
log.Infof("registering original nameservers %v as upstream handlers with priority %d", servers, PriorityFallback)
|
||||
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
@@ -797,11 +808,6 @@ func (s *DefaultServer) registerFallback() {
|
||||
return
|
||||
}
|
||||
handler.selectedRoutes = s.selectedRoutes
|
||||
|
||||
var servers []netip.AddrPort
|
||||
for _, ns := range originalNameservers {
|
||||
servers = append(servers, netip.AddrPortFrom(ns, DefaultPort))
|
||||
}
|
||||
handler.addRace(servers)
|
||||
|
||||
prev := s.fallbackHandler
|
||||
|
||||
@@ -700,6 +700,13 @@ func resolveURLsToIPs(urls []string) []net.IP {
|
||||
|
||||
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
|
||||
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
|
||||
// An explicit user "deselect all" must not be overridden by management auto-apply.
|
||||
// Auto-applying an exit node here would call SelectRoutes, which clears the
|
||||
// deselect-all flag and re-enables every route the user turned off.
|
||||
if m.routeSelector.IsDeselectAll() {
|
||||
return
|
||||
}
|
||||
|
||||
exitNodeInfo := m.collectExitNodeInfo(clientRoutes)
|
||||
if len(exitNodeInfo.allIDs) == 0 {
|
||||
return
|
||||
|
||||
71
client/internal/routemanager/selector_management_test.go
Normal file
71
client/internal/routemanager/selector_management_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func exitNodeRoutes(netID route.NetID, skipAutoApply bool) route.HAMap {
|
||||
haID := route.HAUniqueID(string(netID) + "|0.0.0.0/0")
|
||||
return route.HAMap{
|
||||
haID: []*route.Route{
|
||||
{
|
||||
ID: "r-" + route.ID(netID),
|
||||
NetID: netID,
|
||||
Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Enabled: true,
|
||||
SkipAutoApply: skipAutoApply,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement(t *testing.T) {
|
||||
t.Run("management auto-apply selects exit node without user selection", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
routes := exitNodeRoutes("exit1", false)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsSelected("exit1"), "auto-apply exit node should be selected")
|
||||
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "selected exit node should pass the filter")
|
||||
})
|
||||
|
||||
t.Run("management SkipAutoApply leaves exit node deselected", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
routes := exitNodeRoutes("exit1", true)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.False(t, m.routeSelector.IsSelected("exit1"), "SkipAutoApply exit node should not be selected")
|
||||
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "deselected exit node should be filtered out")
|
||||
})
|
||||
|
||||
t.Run("user selection is not overridden by management", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
require.NoError(t, m.routeSelector.SelectRoutes([]route.NetID{"exit1"}, true, []route.NetID{"exit1"}))
|
||||
routes := exitNodeRoutes("exit1", true)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsSelected("exit1"), "explicit user selection must survive a management sync that wants to skip auto-apply")
|
||||
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "user-selected exit node should pass the filter")
|
||||
})
|
||||
|
||||
t.Run("deselect-all is preserved across a management sync", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
m.routeSelector.DeselectAllRoutes()
|
||||
routes := exitNodeRoutes("exit1", false)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsDeselectAll(), "an explicit deselect-all must not be cleared by management auto-apply")
|
||||
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "no routes should be selected while deselect-all is set")
|
||||
})
|
||||
}
|
||||
@@ -116,6 +116,14 @@ func (rs *RouteSelector) DeselectAllRoutes() {
|
||||
clear(rs.selectedRoutes)
|
||||
}
|
||||
|
||||
// IsDeselectAll reports whether the user has explicitly deselected all routes.
|
||||
func (rs *RouteSelector) IsDeselectAll() bool {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
return rs.deselectAll
|
||||
}
|
||||
|
||||
// IsSelected checks if a specific route is selected.
|
||||
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||
rs.mu.RLock()
|
||||
|
||||
@@ -99,6 +99,9 @@ func addFields(entry *logrus.Entry) {
|
||||
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
|
||||
entry.Data[context.AccountIDKey] = ctxAccountID
|
||||
}
|
||||
if ctxUserAgent, ok := entry.Context.Value(context.UserAgentKey).(string); ok {
|
||||
entry.Data[context.UserAgentKey] = ctxUserAgent
|
||||
}
|
||||
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
|
||||
entry.Data[context.UserIDKey] = ctxInitiatorID
|
||||
}
|
||||
|
||||
@@ -19,6 +19,46 @@ readonly MSG_SEPARATOR="=========================================="
|
||||
# Utility Functions
|
||||
############################################
|
||||
|
||||
check_docker_sock_perms() {
|
||||
local sock="${DOCKER_HOST:-unix:///var/run/docker.sock}"
|
||||
sock="${sock#unix://}"
|
||||
|
||||
if [[ ! -S "$sock" ]]; then
|
||||
return 0
|
||||
fi
|
||||
|
||||
if [[ ! -r "$sock" ]] || [[ ! -w "$sock" ]]; then
|
||||
local group
|
||||
if [[ "${OSTYPE}" == "darwin"* ]]; then
|
||||
group="$(stat -f '%Sg' "$sock")"
|
||||
else
|
||||
group="$(stat -c '%G' "$sock")"
|
||||
fi
|
||||
|
||||
echo "Cannot access Docker socket: $sock" > /dev/stderr
|
||||
echo "" > /dev/stderr
|
||||
echo "Socket permissions:" > /dev/stderr
|
||||
ls -l "$sock" > /dev/stderr
|
||||
echo "" > /dev/stderr
|
||||
|
||||
if [[ "$group" == "docker" ]]; then
|
||||
echo "Your user may need to be added to the '$group' group:" > /dev/stderr
|
||||
echo " sudo usermod -aG $group \"$USER\"" > /dev/stderr
|
||||
echo "Then log out and back in, or run this for the current shell:" > /dev/stderr
|
||||
echo " newgrp $group" > /dev/stderr
|
||||
echo "Note: newgrp is temporary; usermod is the permanent group change." > /dev/stderr
|
||||
else
|
||||
echo "The Docker socket is owned by the '$group' group, which is not the standard 'docker' group." > /dev/stderr
|
||||
echo "For safety, this script will not suggest adding your user to '$group'." > /dev/stderr
|
||||
echo "Instead, either run this script with appropriate privileges (for example, via sudo) or follow Docker's post-install steps to configure access via the 'docker' group:" > /dev/stderr
|
||||
echo " https://docs.docker.com/engine/install/linux-postinstall/" > /dev/stderr
|
||||
fi
|
||||
|
||||
exit 1
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
check_docker_compose() {
|
||||
if command -v docker-compose &> /dev/null
|
||||
then
|
||||
@@ -581,12 +621,15 @@ start_services_and_show_instructions() {
|
||||
}
|
||||
|
||||
init_environment() {
|
||||
# Check if docker compose is installed using check_docker_compose function
|
||||
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
|
||||
check_docker_sock_perms
|
||||
|
||||
initialize_default_values
|
||||
configure_domain
|
||||
configure_reverse_proxy
|
||||
|
||||
check_jq
|
||||
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
|
||||
|
||||
check_existing_installation
|
||||
generate_configuration_files
|
||||
|
||||
@@ -666,8 +666,10 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
|
||||
case resp := <-conn.sendChan:
|
||||
if err := conn.sendResponse(resp); err != nil {
|
||||
errChan <- err
|
||||
log.WithContext(conn.ctx).Tracef("Failed to send response to proxy %s: %v", conn.proxyID, err)
|
||||
return
|
||||
}
|
||||
log.WithContext(conn.ctx).Tracef("Send response to proxy %s", conn.proxyID)
|
||||
case <-conn.ctx.Done():
|
||||
return
|
||||
}
|
||||
@@ -978,6 +980,7 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
Mode: m.Mode,
|
||||
ListenPort: m.ListenPort,
|
||||
AccessRestrictions: m.AccessRestrictions,
|
||||
Private: m.Private,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
88
management/internals/shared/grpc/proxy_clone_test.go
Normal file
88
management/internals/shared/grpc/proxy_clone_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// authTokenField is the only per-proxy field that shallowCloneMapping must NOT
|
||||
// copy from the source, since callers assign it individually after cloning.
|
||||
const authTokenField = "AuthToken"
|
||||
|
||||
// TestShallowCloneMapping_ClonesAllFields populates every exported field of
|
||||
// ProxyMapping with a non-zero value and verifies the clone carries each one
|
||||
// (except AuthToken). It uses reflection so adding a new field to ProxyMapping
|
||||
// without updating shallowCloneMapping fails this test.
|
||||
func TestShallowCloneMapping_ClonesAllFields(t *testing.T) {
|
||||
src := &proto.ProxyMapping{}
|
||||
populated := populateExportedFields(t, reflect.ValueOf(src).Elem())
|
||||
require.NotEmpty(t, populated, "ProxyMapping should expose fields to populate")
|
||||
|
||||
clone := shallowCloneMapping(src)
|
||||
require.NotNil(t, clone, "clone must not be nil")
|
||||
|
||||
srcVal := reflect.ValueOf(src).Elem()
|
||||
cloneVal := reflect.ValueOf(clone).Elem()
|
||||
|
||||
for _, name := range populated {
|
||||
srcField := srcVal.FieldByName(name).Interface()
|
||||
cloneField := cloneVal.FieldByName(name).Interface()
|
||||
|
||||
if name == authTokenField {
|
||||
assert.Zero(t, cloneField, "AuthToken must not be cloned; it is set per proxy after cloning")
|
||||
continue
|
||||
}
|
||||
|
||||
assert.Equal(t, srcField, cloneField, "field %s must be carried over by shallowCloneMapping", name)
|
||||
}
|
||||
}
|
||||
|
||||
// populateExportedFields sets a non-zero value on every settable exported field
|
||||
// of the struct and returns their names.
|
||||
func populateExportedFields(t *testing.T, v reflect.Value) []string {
|
||||
t.Helper()
|
||||
|
||||
var names []string
|
||||
typ := v.Type()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
structField := typ.Field(i)
|
||||
|
||||
if structField.PkgPath != "" || !field.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
setNonZero(t, field, structField.Name)
|
||||
names = append(names, structField.Name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// setNonZero assigns a deterministic non-zero value based on the field kind.
|
||||
func setNonZero(t *testing.T, field reflect.Value, name string) {
|
||||
t.Helper()
|
||||
|
||||
switch field.Kind() {
|
||||
case reflect.String:
|
||||
field.SetString("non-zero-" + name)
|
||||
case reflect.Bool:
|
||||
field.SetBool(true)
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
field.SetInt(7)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
field.SetUint(7)
|
||||
case reflect.Ptr:
|
||||
field.Set(reflect.New(field.Type().Elem()))
|
||||
case reflect.Slice:
|
||||
field.Set(reflect.MakeSlice(field.Type(), 1, 1))
|
||||
case reflect.Map:
|
||||
field.Set(reflect.MakeMapWithSize(field.Type(), 0))
|
||||
default:
|
||||
t.Fatalf("unhandled field kind %s for field %s; extend setNonZero", field.Kind(), name)
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,7 @@ const (
|
||||
RoleKey = nbcontext.RoleKey
|
||||
UserIDKey = nbcontext.UserIDKey
|
||||
PeerIDKey = nbcontext.PeerIDKey
|
||||
UserAgentKey = nbcontext.UserAgentKey
|
||||
)
|
||||
|
||||
// RoleFromContext returns the role stored in ctx, or empty string and false if absent.
|
||||
|
||||
@@ -1216,6 +1216,7 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
|
||||
Preload("NetworkResources").
|
||||
Preload("Onboarding").
|
||||
Preload("Services.Targets").
|
||||
Preload("Domains").
|
||||
Take(&account, idQueryCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
|
||||
@@ -1302,7 +1303,7 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, 12)
|
||||
errChan := make(chan error, 16)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
@@ -1403,6 +1404,17 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
||||
account.Services = services
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
domains, err := s.ListCustomDomains(ctx, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.Domains = domains
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -21,6 +23,63 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// TestGetAccount_LoadsCustomDomains verifies GetAccount populates account.Domains.
|
||||
// SynthesizePrivateServiceZones depends on this relation to anchor a custom-domain
|
||||
// private service's DNS zone; without the preload the relation is empty and the
|
||||
// service is silently skipped, so a custom domain never resolves on clients.
|
||||
func TestGetAccount_LoadsCustomDomains(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
|
||||
assertGetAccountLoadsCustomDomains(t, store)
|
||||
}
|
||||
|
||||
func TestPostgresql_GetAccount_LoadsCustomDomains(t *testing.T) {
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
assertGetAccountLoadsCustomDomains(t, store)
|
||||
}
|
||||
|
||||
// assertGetAccountLoadsCustomDomains exercises both the gorm and pgx GetAccount
|
||||
// paths: it persists two custom domains and asserts the relation comes back
|
||||
// populated, which SynthesizePrivateServiceZones relies on.
|
||||
func assertGetAccountLoadsCustomDomains(t *testing.T, store Store) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
accountID := "acct-custom-domains"
|
||||
require.NoError(t, store.SaveAccount(ctx, newAccountWithId(ctx, accountID, "user-1", "")))
|
||||
|
||||
_, err := store.CreateCustomDomain(ctx, accountID, "example.com", "eu.proxy.netbird.io", true)
|
||||
require.NoError(t, err, "creating the first custom domain must succeed")
|
||||
_, err = store.CreateCustomDomain(ctx, accountID, "apps.acme.io", "us.proxy.netbird.io", false)
|
||||
require.NoError(t, err, "creating the second custom domain must succeed")
|
||||
|
||||
account, err := store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, account.Domains, 2, "GetAccount must preload the account's custom domains")
|
||||
|
||||
byDomain := map[string]string{}
|
||||
for _, d := range account.Domains {
|
||||
require.NotNil(t, d)
|
||||
byDomain[d.Domain] = d.TargetCluster
|
||||
}
|
||||
assert.Equal(t, "eu.proxy.netbird.io", byDomain["example.com"], "custom domain must carry its target cluster")
|
||||
assert.Equal(t, "us.proxy.netbird.io", byDomain["apps.acme.io"], "custom domain must carry its target cluster")
|
||||
}
|
||||
|
||||
// TestGetAccount_ComprehensiveFieldValidation validates that GetAccount properly loads
|
||||
// all fields and nested objects from the database, including deeply nested structures.
|
||||
func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) {
|
||||
|
||||
@@ -21,6 +21,8 @@ const (
|
||||
httpRequestCounterPrefix = "management.http.request.counter"
|
||||
httpResponseCounterPrefix = "management.http.response.counter"
|
||||
httpRequestDurationPrefix = "management.http.request.duration.ms"
|
||||
|
||||
RequestIDHeader = "X-Request-Id"
|
||||
)
|
||||
|
||||
// WrappedResponseWriter is a wrapper for http.ResponseWriter that allows the
|
||||
@@ -172,6 +174,10 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
|
||||
reqID := xid.New().String()
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.UserAgentKey, r.UserAgent())
|
||||
|
||||
rw.Header().Set(RequestIDHeader, reqID)
|
||||
|
||||
log.WithContext(ctx).Tracef("HTTP request %v: %v %v", reqID, r.Method, r.URL)
|
||||
|
||||
|
||||
@@ -273,7 +273,7 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
|
||||
}
|
||||
|
||||
peerGroups := a.GetPeerGroups(peerID)
|
||||
zonesByCluster := map[string]*nbdns.CustomZone{}
|
||||
zonesByApex := map[string]*nbdns.CustomZone{}
|
||||
|
||||
for _, svc := range a.Services {
|
||||
if svc == nil || !svc.Enabled || !svc.Private {
|
||||
@@ -290,19 +290,24 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
|
||||
continue
|
||||
}
|
||||
|
||||
zone, exists := zonesByCluster[svc.ProxyCluster]
|
||||
serviceDomainZone := a.privateServiceDomainZone(svc)
|
||||
if serviceDomainZone == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
zone, exists := zonesByApex[serviceDomainZone]
|
||||
if !exists {
|
||||
// NonAuthoritative makes this a match-only zone: queries for
|
||||
// names without an explicit record fall through to the
|
||||
// upstream resolver instead of returning NXDOMAIN. Without
|
||||
// it, adding a single private service would black-hole every
|
||||
// other name under the cluster apex.
|
||||
// other name under the zone apex.
|
||||
zone = &nbdns.CustomZone{
|
||||
Domain: dns.Fqdn(svc.ProxyCluster),
|
||||
Domain: dns.Fqdn(serviceDomainZone),
|
||||
Records: []nbdns.SimpleRecord{},
|
||||
NonAuthoritative: true,
|
||||
}
|
||||
zonesByCluster[svc.ProxyCluster] = zone
|
||||
zonesByApex[serviceDomainZone] = zone
|
||||
}
|
||||
|
||||
emitted := 0
|
||||
@@ -340,8 +345,8 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
|
||||
}
|
||||
}
|
||||
|
||||
out := make([]nbdns.CustomZone, 0, len(zonesByCluster))
|
||||
for _, zone := range zonesByCluster {
|
||||
out := make([]nbdns.CustomZone, 0, len(zonesByApex))
|
||||
for _, zone := range zonesByApex {
|
||||
if len(zone.Records) == 0 {
|
||||
continue
|
||||
}
|
||||
@@ -357,6 +362,33 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
|
||||
return out
|
||||
}
|
||||
|
||||
// privateServiceDomainZone returns the DNS zone name for the given private service domain by
|
||||
// looking at the proxy cluster domain then the custom domains.
|
||||
func (a *Account) privateServiceDomainZone(svc *service.Service) string {
|
||||
if domainFromSuffix(svc.Domain, svc.ProxyCluster) {
|
||||
return svc.ProxyCluster
|
||||
}
|
||||
|
||||
// Longest matching custom domain wins
|
||||
zoneName := ""
|
||||
for _, d := range a.Domains {
|
||||
if d == nil || d.TargetCluster != svc.ProxyCluster {
|
||||
continue
|
||||
}
|
||||
if domainFromSuffix(svc.Domain, d.Domain) && len(d.Domain) > len(zoneName) {
|
||||
zoneName = d.Domain
|
||||
}
|
||||
}
|
||||
return zoneName
|
||||
}
|
||||
|
||||
func domainFromSuffix(domain, suffix string) bool {
|
||||
if suffix == "" {
|
||||
return false
|
||||
}
|
||||
return domain == suffix || strings.HasSuffix(domain, "."+suffix)
|
||||
}
|
||||
|
||||
// peerInDistributionGroups reports whether any of the peer's groups
|
||||
// matches the service's bearer-auth distribution_groups.
|
||||
func peerInDistributionGroups(peerGroups LookupMap, distributionGroups []string) bool {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
proxydomain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
@@ -234,6 +235,113 @@ func TestPrivateZone_GetPeerNetworkMap_PeerOutsideGroups_OmitsSynthZone(t *testi
|
||||
assert.False(t, ok, "peer outside the distribution_groups must not see the synth zone")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_CustomDomain_ZoneApexIsRegisteredDomain(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
// A custom-domain service: Domain is the custom FQDN, ProxyCluster
|
||||
// is the cluster serving it, and account.Domains holds the registered
|
||||
// custom domain. The synth zone apex must be the registered domain,
|
||||
// not the cluster, or the client's match-only zone never intercepts
|
||||
// the query.
|
||||
account.Services[0].Domain = "app.example.com"
|
||||
account.Domains = []*proxydomain.Domain{
|
||||
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
|
||||
}
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 1, "custom-domain service must still produce one zone")
|
||||
zone := zones[0]
|
||||
assert.Equal(t, "example.com.", zone.Domain, "zone apex must be the registered custom domain, not the cluster or the service FQDN")
|
||||
assert.True(t, zone.NonAuthoritative, "synth zone must remain match-only")
|
||||
require.Len(t, zone.Records, 1, "custom-domain service yields one A record")
|
||||
rec := zone.Records[0]
|
||||
assert.Equal(t, "app.example.com.", rec.Name, "record name is the custom service FQDN")
|
||||
assert.Equal(t, "100.64.0.99", rec.RData, "record points at the embedded proxy peer's tunnel IP")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_CustomAndFreeDomain_SeparateZones(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Domains = []*proxydomain.Domain{
|
||||
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
|
||||
}
|
||||
account.Services = append(account.Services, &service.Service{
|
||||
ID: "svc-2",
|
||||
AccountID: "acct-1",
|
||||
Name: "custom",
|
||||
Domain: "app.example.com",
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
})
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 2, "a free-domain and a custom-domain service must not collapse into one zone")
|
||||
|
||||
free, ok := findCustomZone(zones, "eu.proxy.netbird.io")
|
||||
require.True(t, ok, "free-domain service keeps the shared cluster-apex zone")
|
||||
require.Len(t, free.Records, 1, "cluster zone carries only the free-domain record")
|
||||
assert.Equal(t, "myapp.eu.proxy.netbird.io.", free.Records[0].Name, "cluster zone record is the free-domain FQDN")
|
||||
|
||||
custom, ok := findCustomZone(zones, "example.com")
|
||||
require.True(t, ok, "custom-domain service gets its own zone at the registered custom domain apex")
|
||||
require.Len(t, custom.Records, 1, "custom zone carries only the custom-domain record")
|
||||
assert.Equal(t, "app.example.com.", custom.Records[0].Name, "custom zone record is the custom-domain FQDN")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_TwoServicesSameCustomDomain_OneZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Domains = []*proxydomain.Domain{
|
||||
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
|
||||
}
|
||||
account.Services[0].Domain = "a.example.com"
|
||||
account.Services = append(account.Services, &service.Service{
|
||||
ID: "svc-2",
|
||||
AccountID: "acct-1",
|
||||
Name: "bapp",
|
||||
Domain: "b.example.com",
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
})
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 1, "two services under the same registered custom domain must share one zone")
|
||||
assert.Equal(t, "example.com.", zones[0].Domain, "shared zone apex is the registered custom domain")
|
||||
require.Len(t, zones[0].Records, 2, "both services surface as records in the shared custom-domain zone")
|
||||
names := []string{zones[0].Records[0].Name, zones[0].Records[1].Name}
|
||||
assert.ElementsMatch(t, []string{"a.example.com.", "b.example.com."}, names, "both custom-domain service FQDNs must surface")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_CustomDomainNotRegistered_NoZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
// Service domain is outside the cluster and no account.Domains entry
|
||||
// covers it: there is no apex that would intercept the query, so the
|
||||
// service must be skipped rather than emit an unmatchable record.
|
||||
account.Services[0].Domain = "app.example.com"
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "a custom-domain service with no registered domain apex must not produce a zone")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_CustomDomainClusterMismatch_NoZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
// The registered custom domain matches the service FQDN by suffix but
|
||||
// targets a different cluster than the service's ProxyCluster. It must
|
||||
// be ignored, leaving no apex to intercept the query — otherwise the
|
||||
// zone would point at this cluster's proxy peers under a domain owned
|
||||
// by a different cluster.
|
||||
account.Services[0].Domain = "app.example.com"
|
||||
account.Domains = []*proxydomain.Domain{
|
||||
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "us.proxy.netbird.io", Validated: true},
|
||||
}
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "a custom domain targeting a different cluster must not anchor the service zone")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_TwoServicesSameCluster_OneZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Services = append(account.Services, &service.Service{
|
||||
@@ -254,3 +362,72 @@ func TestSynthesizePrivateServiceZones_TwoServicesSameCluster_OneZone(t *testing
|
||||
names := []string{zones[0].Records[0].Name, zones[0].Records[1].Name}
|
||||
assert.ElementsMatch(t, []string{"myapp.eu.proxy.netbird.io.", "anotherapp.eu.proxy.netbird.io."}, names, "both service domains must surface")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_MixedClusterCustomAndPublic(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Domains = []*proxydomain.Domain{
|
||||
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
|
||||
}
|
||||
|
||||
privateService := func(id, domain string) *service.Service {
|
||||
return &service.Service{
|
||||
ID: id,
|
||||
AccountID: "acct-1",
|
||||
Name: id,
|
||||
Domain: domain,
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
}
|
||||
}
|
||||
publicService := func(id, domain string) *service.Service {
|
||||
s := privateService(id, domain)
|
||||
s.Private = false
|
||||
return s
|
||||
}
|
||||
|
||||
account.Services = []*service.Service{
|
||||
// 3 private services under the cluster suffix.
|
||||
privateService("cluster-1", "cluster1.eu.proxy.netbird.io"),
|
||||
privateService("cluster-2", "cluster2.eu.proxy.netbird.io"),
|
||||
privateService("cluster-3", "cluster3.eu.proxy.netbird.io"),
|
||||
// 4 private services under the custom domain suffix.
|
||||
privateService("custom-1", "custom1.example.com"),
|
||||
privateService("custom-2", "custom2.example.com"),
|
||||
privateService("custom-3", "custom3.example.com"),
|
||||
privateService("custom-4", "custom4.example.com"),
|
||||
// 2 public services, one per suffix, must not surface.
|
||||
publicService("public-cluster", "public.eu.proxy.netbird.io"),
|
||||
publicService("public-custom", "public.example.com"),
|
||||
}
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 2, "one zone per apex: the cluster apex and the custom domain apex")
|
||||
|
||||
cluster, ok := findCustomZone(zones, "eu.proxy.netbird.io")
|
||||
require.True(t, ok, "cluster-suffix services collapse into the cluster-apex zone")
|
||||
clusterNames := recordNames(cluster)
|
||||
assert.ElementsMatch(t,
|
||||
[]string{"cluster1.eu.proxy.netbird.io.", "cluster2.eu.proxy.netbird.io.", "cluster3.eu.proxy.netbird.io."},
|
||||
clusterNames,
|
||||
"only the 3 private cluster services surface in the cluster zone (public one excluded)")
|
||||
|
||||
custom, ok := findCustomZone(zones, "example.com")
|
||||
require.True(t, ok, "custom-suffix services collapse into the custom-domain-apex zone")
|
||||
customNames := recordNames(custom)
|
||||
assert.ElementsMatch(t,
|
||||
[]string{"custom1.example.com.", "custom2.example.com.", "custom3.example.com.", "custom4.example.com."},
|
||||
customNames,
|
||||
"only the 4 private custom services surface in the custom zone (public one excluded)")
|
||||
}
|
||||
|
||||
// recordNames returns the record names of a zone for order-independent assertions.
|
||||
func recordNames(zone nbdns.CustomZone) []string {
|
||||
names := make([]string, 0, len(zone.Records))
|
||||
for _, r := range zone.Records {
|
||||
names = append(names, r.Name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
@@ -557,7 +557,6 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute
|
||||
return enabledRoutes, disabledRoutes
|
||||
}
|
||||
|
||||
|
||||
func (c *NetworkMapComponents) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
|
||||
var filteredRoutes []*route.Route
|
||||
for _, r := range routes {
|
||||
@@ -628,9 +627,14 @@ func (c *NetworkMapComponents) getDefaultPermit(r *route.Route, includeIPv6 bool
|
||||
|
||||
rules := []*RouteFirewallRule{&rule}
|
||||
|
||||
if includeIPv6 && r.IsDynamic() {
|
||||
isDefaultV4 := r.Network.Addr().Is4() && r.Network.Bits() == 0
|
||||
if includeIPv6 && (r.IsDynamic() || isDefaultV4) {
|
||||
ruleV6 := rule
|
||||
ruleV6.SourceRanges = []string{"::/0"}
|
||||
if isDefaultV4 {
|
||||
ruleV6.Destination = "::/0"
|
||||
ruleV6.RouteID = r.ID + "-v6-default"
|
||||
}
|
||||
rules = append(rules, &ruleV6)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -1029,6 +1030,48 @@ func TestComponents_RouteDefaultPermit(t *testing.T) {
|
||||
assert.True(t, hasDefaultPermit, "route without ACG should have default permit rule with 0.0.0.0/0 source")
|
||||
}
|
||||
|
||||
// TestComponents_ExitNodeDefaultPermitIPv6 verifies that a default exit node route
|
||||
// (0.0.0.0/0) without AccessControlGroups also emits an IPv6 default permit rule
|
||||
// (::/0 source and destination) for peers that support IPv6, mirroring the route
|
||||
// the client installs. Without it, IPv6 traffic is routed to the exit node but
|
||||
// dropped at the forward chain.
|
||||
func TestComponents_ExitNodeDefaultPermitIPv6(t *testing.T) {
|
||||
account, validatedPeers := scalableTestAccount(20, 2)
|
||||
|
||||
routingPeerID := "peer-5"
|
||||
routingPeer := account.Peers[routingPeerID]
|
||||
routingPeer.IPv6 = netip.MustParseAddr("fd00::5")
|
||||
routingPeer.Meta.Capabilities = append(routingPeer.Meta.Capabilities, nbpeer.PeerCapabilityIPv6Overlay)
|
||||
|
||||
account.Routes["route-exit"] = &route.Route{
|
||||
ID: "route-exit", Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
PeerID: routingPeerID, Peer: routingPeer.Key,
|
||||
Enabled: true, Groups: []string{"group-all"}, PeerGroups: []string{"group-0"},
|
||||
AccessControlGroups: []string{},
|
||||
AccountID: "test-account",
|
||||
}
|
||||
|
||||
nm := componentsNetworkMap(account, routingPeerID, validatedPeers)
|
||||
require.NotNil(t, nm)
|
||||
|
||||
hasV4 := false
|
||||
hasV6 := false
|
||||
for _, rfr := range nm.RoutesFirewallRules {
|
||||
switch rfr.Destination {
|
||||
case "0.0.0.0/0":
|
||||
if slices.Contains(rfr.SourceRanges, "0.0.0.0/0") {
|
||||
hasV4 = true
|
||||
}
|
||||
case "::/0":
|
||||
if slices.Contains(rfr.SourceRanges, "::/0") {
|
||||
hasV6 = true
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.True(t, hasV4, "exit node route should have an IPv4 default permit rule (0.0.0.0/0)")
|
||||
assert.True(t, hasV6, "exit node route should have an IPv6 default permit rule (::/0)")
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
// 15. MULTIPLE ROUTERS PER NETWORK
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -249,6 +249,7 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
Private: private,
|
||||
MaxDialTimeout: maxDialTimeout,
|
||||
MaxSessionIdleTimeout: maxSessionIdleTimeout,
|
||||
MappingBatchWatchdog: envDurationOrDefault("NB_PROXY_MAPPING_BATCH_WATCHDOG", 0),
|
||||
GeoDataDir: geoDataDir,
|
||||
CrowdSecAPIURL: crowdsecAPIURL,
|
||||
CrowdSecAPIKey: crowdsecAPIKey,
|
||||
|
||||
@@ -28,6 +28,10 @@ import (
|
||||
|
||||
const deviceNamePrefix = "ingress-proxy-"
|
||||
|
||||
const clientStopTimeout = 30 * time.Second
|
||||
|
||||
const createProxyPeerTimeout = 30 * time.Second
|
||||
|
||||
// backendKey identifies a backend by its host:port from the target URL.
|
||||
type backendKey string
|
||||
|
||||
@@ -162,6 +166,7 @@ type NetBird struct {
|
||||
|
||||
clientsMux sync.RWMutex
|
||||
clients map[types.AccountID]*clientEntry
|
||||
lifecycleMu sync.Map
|
||||
initLogOnce sync.Once
|
||||
statusNotifier statusNotifier
|
||||
// readyHandler runs after the embedded client for an account reports
|
||||
@@ -177,6 +182,10 @@ type NetBird struct {
|
||||
// (i.e. when a new client was actually created, not when an existing one
|
||||
// was reused). The duration covers keygen + gRPC CreateProxyPeer + embed.New.
|
||||
OnAddPeer func(d time.Duration, err error)
|
||||
|
||||
// startClient runs the post-create client startup. Nil uses runClientStartup;
|
||||
// tests override it to avoid a real embed client.Start.
|
||||
startClient func(accountID types.AccountID, client *embed.Client)
|
||||
}
|
||||
|
||||
// ClientDebugInfo contains debug information about a client.
|
||||
@@ -200,31 +209,20 @@ type skipTLSVerifyContextKey struct{}
|
||||
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error {
|
||||
si := serviceInfo{serviceID: serviceID}
|
||||
|
||||
n.clientsMux.Lock()
|
||||
if n.registerExistingClient(accountID, key, si) {
|
||||
return nil
|
||||
}
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
entry.services[key] = si
|
||||
started := entry.started
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).Debug("registered service with existing client")
|
||||
|
||||
if started && n.statusNotifier != nil {
|
||||
// Use a background context, not the caller's: the management
|
||||
// connection notification must land even if the request /
|
||||
// stream that triggered this registration is cancelled.
|
||||
// Mirrors the async runClientStartup path.
|
||||
if err := n.statusNotifier.NotifyStatus(context.Background(), accountID, serviceID, true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).WithError(err).Warn("failed to notify status for existing client")
|
||||
}
|
||||
lifecycle := n.accountLifecycle(accountID)
|
||||
lifecycle.Lock()
|
||||
transferred := false
|
||||
defer func() {
|
||||
if !transferred {
|
||||
lifecycle.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
if n.registerExistingClient(accountID, key, si) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -234,10 +232,10 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
|
||||
n.OnAddPeer(time.Since(createStart), err)
|
||||
}
|
||||
if err != nil {
|
||||
n.clientsMux.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
n.clientsMux.Lock()
|
||||
n.clients[accountID] = entry
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
@@ -246,17 +244,64 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
|
||||
"service_key": key,
|
||||
}).Info("created new client for account")
|
||||
|
||||
// Attempt to start the client in the background; if this fails we will
|
||||
// retry on the first request via RoundTrip. runClientStartup uses its
|
||||
// own background context so the caller's request-scoped ctx can't
|
||||
// cancel the inbound bring-up.
|
||||
go n.runClientStartup(accountID, entry.client)
|
||||
transferred = true
|
||||
go func() {
|
||||
defer lifecycle.Unlock()
|
||||
n.startClientStartup(accountID, entry.client)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *NetBird) startClientStartup(accountID types.AccountID, client *embed.Client) {
|
||||
if n.startClient != nil {
|
||||
n.startClient(accountID, client)
|
||||
return
|
||||
}
|
||||
n.runClientStartup(accountID, client)
|
||||
}
|
||||
|
||||
// registerExistingClient registers the service against an already-present
|
||||
// client for the account and returns true when it did. It notifies management
|
||||
// of the new service when the client is already started.
|
||||
func (n *NetBird) registerExistingClient(accountID types.AccountID, key ServiceKey, si serviceInfo) bool {
|
||||
n.clientsMux.Lock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
n.clientsMux.Unlock()
|
||||
return false
|
||||
}
|
||||
entry.services[key] = si
|
||||
started := entry.started
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).Debug("registered service with existing client")
|
||||
|
||||
if started && n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(context.Background(), accountID, si.serviceID, true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).WithError(err).Warn("failed to notify status for existing client")
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// accountLifecycle returns the per-account lifecycle mutex, serialising client
|
||||
// creation against teardown so a slow client.Stop cannot race a new
|
||||
// client.Start for the same account, without blocking clientsMux.
|
||||
func (n *NetBird) accountLifecycle(accountID types.AccountID) *sync.Mutex {
|
||||
mu, _ := n.lifecycleMu.LoadOrStore(accountID, &sync.Mutex{})
|
||||
return mu.(*sync.Mutex)
|
||||
}
|
||||
|
||||
// createClientEntry generates a WireGuard keypair, authenticates with management,
|
||||
// and creates an embedded NetBird client. Must be called with clientsMux held.
|
||||
// and creates an embedded NetBird client. Must be called with the account's
|
||||
// lifecycle mutex held.
|
||||
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) {
|
||||
serviceID := si.serviceID
|
||||
n.logger.WithFields(log.Fields{
|
||||
@@ -276,7 +321,9 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
"public_key": publicKey.String(),
|
||||
}).Debug("authenticating new proxy peer with management")
|
||||
|
||||
resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
|
||||
createCtx, cancel := context.WithTimeout(ctx, createProxyPeerTimeout)
|
||||
defer cancel()
|
||||
resp, err := n.mgmtClient.CreateProxyPeer(createCtx, &proto.CreateProxyPeerRequest{
|
||||
ServiceId: string(serviceID),
|
||||
AccountId: string(accountID),
|
||||
Token: authToken,
|
||||
@@ -444,6 +491,15 @@ func (n *NetBird) notifyClientReady(accountID types.AccountID, client *embed.Cli
|
||||
// RemovePeer unregisters a service from an account. The client is only stopped
|
||||
// when no services are using it anymore.
|
||||
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key ServiceKey) error {
|
||||
lifecycle := n.accountLifecycle(accountID)
|
||||
lifecycle.Lock()
|
||||
transferred := false
|
||||
defer func() {
|
||||
if !transferred {
|
||||
lifecycle.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
n.clientsMux.Lock()
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
@@ -466,17 +522,8 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
|
||||
delete(entry.services, key)
|
||||
|
||||
stopClient := len(entry.services) == 0
|
||||
var client *embed.Client
|
||||
var transport, insecureTransport *http.Transport
|
||||
var inbound any
|
||||
var stopHandler func(types.AccountID, any)
|
||||
if stopClient {
|
||||
n.logger.WithField("account_id", accountID).Info("stopping client, no more services")
|
||||
client = entry.client
|
||||
transport = entry.transport
|
||||
insecureTransport = entry.insecureTransport
|
||||
inbound = entry.inbound
|
||||
stopHandler = n.stopHandler
|
||||
delete(n.clients, accountID)
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
@@ -490,19 +537,40 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
|
||||
n.notifyDisconnect(ctx, accountID, key, si.serviceID)
|
||||
|
||||
if stopClient {
|
||||
if inbound != nil && stopHandler != nil {
|
||||
stopHandler(accountID, inbound)
|
||||
}
|
||||
transport.CloseIdleConnections()
|
||||
insecureTransport.CloseIdleConnections()
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
|
||||
}
|
||||
transferred = true
|
||||
go n.stopClientLocked(accountID, lifecycle, entry)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// stopClientLocked releases a client's resources off the caller's goroutine so a
|
||||
// slow client.Stop cannot wedge the mapping receive loop (which calls RemovePeer
|
||||
// synchronously). It unlocks lifecycle when done so a new client.Start for the
|
||||
// same account waits for this teardown.
|
||||
func (n *NetBird) stopClientLocked(accountID types.AccountID, lifecycle *sync.Mutex, entry *clientEntry) {
|
||||
defer lifecycle.Unlock()
|
||||
|
||||
if entry.inbound != nil && n.stopHandler != nil {
|
||||
n.stopHandler(accountID, entry.inbound)
|
||||
}
|
||||
if entry.transport != nil {
|
||||
entry.transport.CloseIdleConnections()
|
||||
}
|
||||
if entry.insecureTransport != nil {
|
||||
entry.insecureTransport.CloseIdleConnections()
|
||||
}
|
||||
if entry.client == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), clientStopTimeout)
|
||||
defer cancel()
|
||||
if err := entry.client.Stop(ctx); err != nil {
|
||||
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NetBird) notifyDisconnect(ctx context.Context, accountID types.AccountID, key ServiceKey, serviceID types.ServiceID) {
|
||||
if n.statusNotifier == nil {
|
||||
return
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -22,6 +23,18 @@ func (m *mockMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxy
|
||||
return &proto.CreateProxyPeerResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
// signalMgmtClient closes entered the first time CreateProxyPeer is called, so
|
||||
// tests can detect AddPeer reaching client creation.
|
||||
type signalMgmtClient struct {
|
||||
entered chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (m *signalMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
|
||||
m.once.Do(func() { close(m.entered) })
|
||||
return &proto.CreateProxyPeerResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
type mockStatusNotifier struct {
|
||||
mu sync.Mutex
|
||||
statuses []statusCall
|
||||
@@ -52,11 +65,15 @@ func (m *mockStatusNotifier) calls() []statusCall {
|
||||
// mockNetBird creates a NetBird instance for testing without actually connecting.
|
||||
// It uses an invalid management URL to prevent real connections.
|
||||
func mockNetBird() *NetBird {
|
||||
return NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
MgmtAddr: "http://invalid.test:9999",
|
||||
WGPort: 0,
|
||||
PreSharedKey: "",
|
||||
}, nil, nil, &mockMgmtClient{})
|
||||
// Skip the real embed client.Start, which would hang against the unreachable
|
||||
// mgmt URL and (now that the lifecycle lock spans startup) serialise removes.
|
||||
nb.startClient = func(types.AccountID, *embed.Client) {}
|
||||
return nb
|
||||
}
|
||||
|
||||
func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
|
||||
@@ -288,6 +305,7 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
|
||||
WGPort: 0,
|
||||
PreSharedKey: "",
|
||||
}, nil, notifier, &mockMgmtClient{})
|
||||
nb.startClient = func(types.AccountID, *embed.Client) {}
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add first service — creates a new client entry.
|
||||
@@ -372,6 +390,117 @@ func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
|
||||
assert.False(t, calls[0].connected)
|
||||
}
|
||||
|
||||
// TestNetBird_RemovePeer_TeardownIsAsync proves the fix for the receive-loop
|
||||
// stall: RemovePeer must return promptly even when the client teardown blocks,
|
||||
// because teardown runs off the caller's goroutine. The receive loop calls
|
||||
// RemovePeer synchronously, so a blocking teardown inline would wedge it.
|
||||
func TestNetBird_RemovePeer_TeardownIsAsync(t *testing.T) {
|
||||
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
MgmtAddr: "http://invalid.test:9999",
|
||||
}, nil, &mockStatusNotifier{}, &mockMgmtClient{})
|
||||
|
||||
accountID := types.AccountID("acct-async-teardown")
|
||||
key := DomainServiceKey("svc.example")
|
||||
|
||||
teardownEntered := make(chan struct{})
|
||||
releaseTeardown := make(chan struct{})
|
||||
nb.SetClientLifecycle(nil, func(types.AccountID, any) {
|
||||
close(teardownEntered)
|
||||
<-releaseTeardown
|
||||
})
|
||||
|
||||
nb.clientsMux.Lock()
|
||||
nb.clients[accountID] = &clientEntry{
|
||||
services: map[ServiceKey]serviceInfo{key: {serviceID: types.ServiceID("svc-1")}},
|
||||
started: true,
|
||||
inbound: struct{}{},
|
||||
}
|
||||
nb.clientsMux.Unlock()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() { done <- nb.RemovePeer(context.Background(), accountID, key) }()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("RemovePeer did not return while teardown was blocked — teardown is not async")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-teardownEntered:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("teardown never ran")
|
||||
}
|
||||
|
||||
close(releaseTeardown)
|
||||
}
|
||||
|
||||
// TestNetBird_AddPeer_WaitsForTeardown proves the lifecycle lock serialises a
|
||||
// new client bringup behind an in-flight teardown for the same account, so a
|
||||
// slow client.Stop can never race a new client.Start for that account.
|
||||
//
|
||||
// It targets the handoff race specifically: AddPeer is launched immediately
|
||||
// after RemovePeer returns, WITHOUT waiting for the teardown goroutine to start.
|
||||
// This only passes if RemovePeer acquires the lifecycle lock synchronously
|
||||
// (before returning) and hands it to the teardown goroutine — if the goroutine
|
||||
// acquired the lock itself, AddPeer could win the lock in this window and start
|
||||
// a replacement client while the old teardown is still pending.
|
||||
func TestNetBird_AddPeer_WaitsForTeardown(t *testing.T) {
|
||||
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
MgmtAddr: "http://invalid.test:9999",
|
||||
}, nil, &mockStatusNotifier{}, &mockMgmtClient{})
|
||||
nb.startClient = func(types.AccountID, *embed.Client) {}
|
||||
|
||||
accountID := types.AccountID("acct-serialize")
|
||||
key := DomainServiceKey("svc.example")
|
||||
|
||||
addEntered := make(chan struct{})
|
||||
releaseTeardown := make(chan struct{})
|
||||
nb.SetClientLifecycle(nil, func(types.AccountID, any) {
|
||||
// Block teardown until released. If AddPeer ever reaches createClientEntry
|
||||
// (signalled via the mgmt client below) while we hold the lock, the lock
|
||||
// failed to serialise and the test fails before we release.
|
||||
<-releaseTeardown
|
||||
})
|
||||
|
||||
nb.clientsMux.Lock()
|
||||
nb.clients[accountID] = &clientEntry{
|
||||
services: map[ServiceKey]serviceInfo{key: {serviceID: types.ServiceID("svc-1")}},
|
||||
started: true,
|
||||
inbound: struct{}{},
|
||||
}
|
||||
nb.clientsMux.Unlock()
|
||||
|
||||
// createClientEntry calls CreateProxyPeer; closing addEntered there tells us
|
||||
// AddPeer got past the lifecycle lock and into client creation.
|
||||
nb.mgmtClient = &signalMgmtClient{entered: addEntered}
|
||||
|
||||
require.NoError(t, nb.RemovePeer(context.Background(), accountID, key))
|
||||
|
||||
// Launch AddPeer with NO synchronisation against the teardown goroutine.
|
||||
addReturned := make(chan struct{})
|
||||
go func() {
|
||||
_ = nb.AddPeer(context.Background(), accountID, DomainServiceKey("svc2.example"), "key-2", types.ServiceID("svc-2"))
|
||||
close(addReturned)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-addEntered:
|
||||
t.Fatal("AddPeer entered client creation while teardown held the lifecycle lock — handoff race not closed")
|
||||
case <-addReturned:
|
||||
t.Fatal("AddPeer completed while teardown held the lifecycle lock — not serialised")
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
}
|
||||
|
||||
close(releaseTeardown)
|
||||
select {
|
||||
case <-addReturned:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("AddPeer never completed after teardown released the lifecycle lock")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotifyClientReady_UsesBackgroundCtx pins the contract that the
|
||||
// post-Start hooks (readyHandler + statusNotifier.NotifyStatus) run on
|
||||
// a fresh context.Background() rather than inheriting the AddPeer
|
||||
|
||||
@@ -114,6 +114,10 @@ type Config struct {
|
||||
MaxDialTimeout time.Duration
|
||||
// MaxSessionIdleTimeout caps the per-service session idle timeout.
|
||||
MaxSessionIdleTimeout time.Duration
|
||||
// MappingBatchWatchdog bounds how long a single mapping batch may spend
|
||||
// being applied before the receive loop reconnects to resync. Zero falls
|
||||
// back to the internal default.
|
||||
MappingBatchWatchdog time.Duration
|
||||
|
||||
// GeoDataDir is the directory containing GeoLite2 MMDB files.
|
||||
GeoDataDir string
|
||||
@@ -164,6 +168,7 @@ func New(ctx context.Context, cfg Config) *Server {
|
||||
Private: cfg.Private,
|
||||
MaxDialTimeout: cfg.MaxDialTimeout,
|
||||
MaxSessionIdleTimeout: cfg.MaxSessionIdleTimeout,
|
||||
MappingBatchWatchdog: cfg.MappingBatchWatchdog,
|
||||
GeoDataDir: cfg.GeoDataDir,
|
||||
CrowdSecAPIURL: cfg.CrowdSecAPIURL,
|
||||
CrowdSecAPIKey: cfg.CrowdSecAPIKey,
|
||||
|
||||
282
proxy/mapping_stall_test.go
Normal file
282
proxy/mapping_stall_test.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// blockingMgmtClient implements roundtrip's managementClient interface.
|
||||
// CreateProxyPeer parks until release is closed, signalling entry on entered.
|
||||
// This reproduces the confirmed real-world stall: createClientEntry calls
|
||||
// CreateProxyPeer synchronously while holding clientsMux, and the proxy's
|
||||
// receive loop calls that path synchronously inside processMappings.
|
||||
type blockingMgmtClient struct {
|
||||
entered chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (b *blockingMgmtClient) CreateProxyPeer(ctx context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
|
||||
b.once.Do(func() { close(b.entered) })
|
||||
// Park until the caller's context is cancelled. In production this ctx is
|
||||
// the gRPC mapping-stream context with no per-call timeout, so a slow or
|
||||
// unresponsive CreateProxyPeer parks the receive loop here indefinitely.
|
||||
<-ctx.Done()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
// gatedMappingStream is a mock GetMappingUpdate client stream that hands out a
|
||||
// pre-seeded list of messages, then records how many times Recv advanced. It
|
||||
// lets the test observe whether the single-threaded receive loop ever gets
|
||||
// past the first (blocking) batch to pull the second message.
|
||||
type gatedMappingStream struct {
|
||||
grpc.ClientStream
|
||||
messages []*proto.GetMappingUpdateResponse
|
||||
idx int32
|
||||
}
|
||||
|
||||
func (g *gatedMappingStream) Recv() (*proto.GetMappingUpdateResponse, error) {
|
||||
i := int(atomic.LoadInt32(&g.idx))
|
||||
if i >= len(g.messages) {
|
||||
// Block instead of returning EOF so the loop doesn't exit; we only
|
||||
// care whether the loop ever reaches this second Recv at all.
|
||||
select {}
|
||||
}
|
||||
msg := g.messages[i]
|
||||
atomic.AddInt32(&g.idx, 1)
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (g *gatedMappingStream) deliveredCount() int32 { return atomic.LoadInt32(&g.idx) }
|
||||
|
||||
func (g *gatedMappingStream) Header() (metadata.MD, error) { return nil, nil } //nolint:nilnil
|
||||
func (g *gatedMappingStream) Trailer() metadata.MD { return nil }
|
||||
func (g *gatedMappingStream) CloseSend() error { return nil }
|
||||
func (g *gatedMappingStream) Context() context.Context { return context.Background() }
|
||||
func (g *gatedMappingStream) SendMsg(any) error { return nil }
|
||||
func (g *gatedMappingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
// noopNotifier satisfies roundtrip's statusNotifier interface.
|
||||
type noopNotifier struct{}
|
||||
|
||||
func (noopNotifier) NotifyStatus(context.Context, types.AccountID, types.ServiceID, bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// noopProxyClient is a proto.ProxyServiceClient that no-ops the one method the
|
||||
// teardown unwind reaches (SendStatusUpdate, via notifyError when the parked
|
||||
// AddPeer is cancelled). The embedded nil interface satisfies the rest at
|
||||
// compile time; none of those methods are called by this test.
|
||||
type noopProxyClient struct {
|
||||
proto.ProxyServiceClient
|
||||
}
|
||||
|
||||
func (noopProxyClient) SendStatusUpdate(context.Context, *proto.SendStatusUpdateRequest, ...grpc.CallOption) (*proto.SendStatusUpdateResponse, error) {
|
||||
return &proto.SendStatusUpdateResponse{}, nil
|
||||
}
|
||||
|
||||
// TestMappingStream_StallsWhenApplyBlocks proves the deadlock: the proxy's
|
||||
// mapping receive loop processes batches strictly serially, so when applying
|
||||
// one batch blocks (here: createClientEntry parked on a synchronous
|
||||
// CreateProxyPeer call, exactly as observed in production), the loop never
|
||||
// advances to Recv the next batch. Management can keep sending updates onto
|
||||
// the stream with no error and no channel overflow, yet the proxy applies
|
||||
// nothing further — it is stuck.
|
||||
func TestMappingStream_StallsWhenApplyBlocks(t *testing.T) {
|
||||
logger := log.New()
|
||||
logger.SetLevel(log.PanicLevel)
|
||||
|
||||
mgmt := &blockingMgmtClient{
|
||||
entered: make(chan struct{}),
|
||||
}
|
||||
|
||||
nb := roundtrip.NewNetBird(
|
||||
context.Background(),
|
||||
"proxy-test",
|
||||
"proxy.example.com",
|
||||
roundtrip.ClientConfig{},
|
||||
logger,
|
||||
noopNotifier{},
|
||||
mgmt,
|
||||
)
|
||||
|
||||
s := &Server{
|
||||
Logger: logger,
|
||||
netbird: nb,
|
||||
mgmtClient: noopProxyClient{},
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
// First batch: a CREATED mapping for a brand-new account. addMapping ->
|
||||
// netbird.AddPeer -> createClientEntry -> CreateProxyPeer, which blocks.
|
||||
// Empty Path keeps setupHTTPMapping a no-op (it returns early), so the
|
||||
// ONLY blocking point is the synchronous CreateProxyPeer in AddPeer —
|
||||
// no routers/auth need wiring. The second batch exists only to detect
|
||||
// whether the loop ever advances past the blocked first batch.
|
||||
stream := &gatedMappingStream{
|
||||
messages: []*proto.GetMappingUpdateResponse{
|
||||
{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "svc-1",
|
||||
AccountId: "acct-1",
|
||||
AuthToken: "token-1",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "svc-2",
|
||||
AccountId: "acct-2",
|
||||
AuthToken: "token-2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// Unblock the parked apply on teardown via ctx (CreateProxyPeer returns
|
||||
// ctx.Err()), so the wedged loop goroutine unwinds before embed.New —
|
||||
// avoiding any dependency on collaborators this test deliberately leaves
|
||||
// nil. The deadlock is fully proven before this fires.
|
||||
t.Cleanup(cancel)
|
||||
|
||||
loopDone := make(chan struct{})
|
||||
syncDone := false
|
||||
go func() {
|
||||
defer close(loopDone)
|
||||
_ = s.handleMappingStream(ctx, stream, &syncDone, time.Time{})
|
||||
}()
|
||||
|
||||
// The loop must reach the blocking apply for the first batch.
|
||||
select {
|
||||
case <-mgmt.entered:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("receive loop never reached CreateProxyPeer for the first batch")
|
||||
}
|
||||
|
||||
// THE DEADLOCK: while the first batch is parked in CreateProxyPeer, the
|
||||
// single-threaded loop cannot advance. The second batch is never pulled,
|
||||
// even though it is already available on the stream. Give it ample time.
|
||||
// deliveredCount is atomic; syncDone is intentionally not read here because
|
||||
// the loop goroutine owns it (reading it from the test would race).
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
assert.Equal(t, int32(1), stream.deliveredCount(),
|
||||
"loop must NOT consume the second batch while the first is blocked in apply — proxy is stuck")
|
||||
|
||||
select {
|
||||
case <-loopDone:
|
||||
t.Fatal("receive loop returned while it should be wedged in apply")
|
||||
default:
|
||||
// Still wedged, as expected.
|
||||
}
|
||||
}
|
||||
|
||||
// TestMappingStream_StallsWhenRemoveBlocks proves the deadlock for the REMOVE
|
||||
// path observed in production: a mapping remove tears down the account's last
|
||||
// embedded client via netbird.RemovePeer -> client.Stop -> Engine.Stop, whose
|
||||
// jobExecutorWG.Wait() is unbounded. Because the receive loop is single-
|
||||
// threaded, a blocked remove wedges the loop: no further mapping updates of any
|
||||
// kind (create/modify/remove) are applied, while management keeps sending them
|
||||
// successfully (no send error, no channel-full). Matches the reported symptom:
|
||||
// the last log line is a remove that stops a client, then silence.
|
||||
func TestMappingStream_StallsWhenRemoveBlocks(t *testing.T) {
|
||||
logger := log.New()
|
||||
logger.SetLevel(log.PanicLevel)
|
||||
|
||||
enteredRemove := make(chan struct{})
|
||||
blockRemove := make(chan struct{})
|
||||
var once sync.Once
|
||||
|
||||
s := &Server{
|
||||
Logger: logger,
|
||||
mgmtClient: noopProxyClient{},
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
// Stand in for netbird.RemovePeer -> client.Stop hanging on
|
||||
// Engine.Stop's unbounded jobExecutorWG.Wait(). Only the first remove
|
||||
// blocks; later removes return immediately so the recovery assertion
|
||||
// can observe the loop advancing.
|
||||
removePeer: func(ctx context.Context, _ types.AccountID, _ roundtrip.ServiceKey) error {
|
||||
first := false
|
||||
once.Do(func() {
|
||||
first = true
|
||||
close(enteredRemove)
|
||||
})
|
||||
if !first {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-blockRemove:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Batch 1 removes a service (blocks in teardown). Batch 2 is a later update
|
||||
// that must never be applied while the remove is wedged.
|
||||
stream := &gatedMappingStream{
|
||||
messages: []*proto.GetMappingUpdateResponse{
|
||||
{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, Id: "svc-1", AccountId: "acct-1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, Id: "svc-2", AccountId: "acct-1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
loopDone := make(chan struct{})
|
||||
syncDone := false
|
||||
go func() {
|
||||
defer close(loopDone)
|
||||
_ = s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{})
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-enteredRemove:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("receive loop never reached the blocking remove for the first batch")
|
||||
}
|
||||
|
||||
// THE DEADLOCK: the loop is parked in the blocked remove and cannot advance.
|
||||
// syncDone is owned by the loop goroutine, so it is not read here.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
assert.Equal(t, int32(1), stream.deliveredCount(),
|
||||
"loop must NOT consume the second batch while the first remove is blocked — proxy is stuck")
|
||||
|
||||
select {
|
||||
case <-loopDone:
|
||||
t.Fatal("receive loop returned while it should be wedged on the remove")
|
||||
default:
|
||||
}
|
||||
|
||||
// Unblock and confirm the wedge was solely the blocked remove: the loop
|
||||
// then advances and consumes the next batch.
|
||||
close(blockRemove)
|
||||
assert.Eventually(t, func() bool {
|
||||
return stream.deliveredCount() >= 2
|
||||
}, 2*time.Second, 5*time.Millisecond,
|
||||
"once the remove unblocks, the loop must advance and consume the next batch")
|
||||
}
|
||||
@@ -118,6 +118,9 @@ type Server struct {
|
||||
// The mapping worker waits on this before processing updates.
|
||||
routerReady chan struct{}
|
||||
|
||||
// removePeer defaults to netbird.RemovePeer; overridable in tests.
|
||||
removePeer func(ctx context.Context, accountID types.AccountID, key roundtrip.ServiceKey) error
|
||||
|
||||
// inbound, when non-nil, manages per-account inbound listeners. Set by
|
||||
// initPrivateInbound only when Private is true so the standalone
|
||||
// proxy keeps its zero-overhead default path.
|
||||
@@ -227,6 +230,10 @@ type Server struct {
|
||||
// Zero means no cap (the proxy honors whatever management sends).
|
||||
// Set via NB_PROXY_MAX_SESSION_IDLE_TIMEOUT for shared deployments.
|
||||
MaxSessionIdleTimeout time.Duration
|
||||
// MappingBatchWatchdog bounds how long a single mapping batch may spend
|
||||
// in processMappings before the receive loop reconnects to resync.
|
||||
// Zero uses defaultMappingBatchWatchdog.
|
||||
MappingBatchWatchdog time.Duration
|
||||
}
|
||||
|
||||
// clampIdleTimeout returns d capped to MaxSessionIdleTimeout when configured.
|
||||
@@ -1172,24 +1179,30 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
|
||||
s.healthChecker.SetManagementConnected(false)
|
||||
}
|
||||
|
||||
connected := false
|
||||
onConnected := func() { connected = true }
|
||||
|
||||
var streamErr error
|
||||
if syncSupported {
|
||||
streamErr = s.trySyncMappings(ctx, client, &initialSyncDone)
|
||||
streamErr = s.trySyncMappings(ctx, client, &initialSyncDone, onConnected)
|
||||
if isSyncUnimplemented(streamErr) {
|
||||
syncSupported = false
|
||||
s.Logger.Info("management does not support SyncMappings, falling back to GetMappingUpdate")
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone, onConnected)
|
||||
}
|
||||
} else {
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone, onConnected)
|
||||
}
|
||||
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(false)
|
||||
}
|
||||
|
||||
// Stream established — reset backoff so the next failure retries quickly.
|
||||
bo.Reset()
|
||||
// Reset backoff only when a stream actually connected, so immediate
|
||||
// connect failures still back off instead of spinning.
|
||||
if connected {
|
||||
bo.Reset()
|
||||
}
|
||||
|
||||
if streamErr == nil {
|
||||
return fmt.Errorf("stream closed by server")
|
||||
@@ -1221,7 +1234,7 @@ func (s *Server) proxyCapabilities() *proto.ProxyCapabilities {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
|
||||
func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool, onConnected func()) error {
|
||||
connectTime := time.Now()
|
||||
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: s.ID,
|
||||
@@ -1234,6 +1247,7 @@ func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServ
|
||||
return fmt.Errorf("create mapping stream: %w", err)
|
||||
}
|
||||
|
||||
onConnected()
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(true)
|
||||
}
|
||||
@@ -1242,7 +1256,7 @@ func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServ
|
||||
return s.handleMappingStream(ctx, mappingClient, initialSyncDone, connectTime)
|
||||
}
|
||||
|
||||
func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
|
||||
func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool, onConnected func()) error {
|
||||
connectTime := time.Now()
|
||||
stream, err := client.SyncMappings(ctx)
|
||||
if err != nil {
|
||||
@@ -1263,6 +1277,7 @@ func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceC
|
||||
return fmt.Errorf("send sync init: %w", err)
|
||||
}
|
||||
|
||||
onConnected()
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(true)
|
||||
}
|
||||
@@ -1307,7 +1322,9 @@ func (s *Server) handleSyncMappingsStream(ctx context.Context, stream proto.Prox
|
||||
|
||||
batchStart := time.Now()
|
||||
s.Logger.Debug("Received mapping update, starting processing")
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
if err := s.processMappingsGuarded(ctx, msg.GetMapping()); err != nil {
|
||||
return err
|
||||
}
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
|
||||
|
||||
@@ -1391,7 +1408,9 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
||||
|
||||
batchStart := time.Now()
|
||||
s.Logger.Debug("Received mapping update, starting processing")
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
if err := s.processMappingsGuarded(ctx, msg.GetMapping()); err != nil {
|
||||
return err
|
||||
}
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
|
||||
}
|
||||
@@ -1456,6 +1475,44 @@ func redactMappingForLog(m *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
return c
|
||||
}
|
||||
|
||||
const defaultMappingBatchWatchdog = 2 * time.Minute
|
||||
|
||||
// mappingBatchWatchdog returns the configured batch watchdog or the default.
|
||||
func (s *Server) mappingBatchWatchdog() time.Duration {
|
||||
if s.MappingBatchWatchdog > 0 {
|
||||
return s.MappingBatchWatchdog
|
||||
}
|
||||
return defaultMappingBatchWatchdog
|
||||
}
|
||||
|
||||
// processMappingsGuarded applies a batch under a watchdog, returning an error
|
||||
// if processing exceeds the watchdog so the caller reconnects and resyncs
|
||||
// instead of wedging silently.
|
||||
func (s *Server) processMappingsGuarded(ctx context.Context, mappings []*proto.ProxyMapping) error {
|
||||
batchCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
s.processMappings(batchCtx, mappings)
|
||||
}()
|
||||
|
||||
watchdog := s.mappingBatchWatchdog()
|
||||
timer := time.NewTimer(watchdog)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
s.Logger.Errorf("processing mapping batch exceeded %s, cancelling and reconnecting to resync", watchdog)
|
||||
return fmt.Errorf("mapping batch processing stalled after %s", watchdog)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
|
||||
debug := s.Logger != nil && s.Logger.IsLevelEnabled(log.DebugLevel)
|
||||
for _, mapping := range mappings {
|
||||
@@ -1951,7 +2008,11 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
|
||||
func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) {
|
||||
accountID := types.AccountID(mapping.GetAccountId())
|
||||
svcKey := s.serviceKeyForMapping(mapping)
|
||||
if err := s.netbird.RemovePeer(ctx, accountID, svcKey); err != nil {
|
||||
removePeer := s.removePeer
|
||||
if removePeer == nil {
|
||||
removePeer = s.netbird.RemovePeer
|
||||
}
|
||||
if err := removePeer(ctx, accountID, svcKey); err != nil {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_id": mapping.GetId(),
|
||||
|
||||
@@ -417,15 +417,30 @@ if type uname >/dev/null 2>&1; then
|
||||
# Check the availability of a compatible package manager
|
||||
if check_use_bin_variable; then
|
||||
PACKAGE_MANAGER="bin"
|
||||
elif [ -e /run/ostree-booted ]; then
|
||||
if [ -x "$(command -v rpm-ostree)" ]; then
|
||||
PACKAGE_MANAGER="rpm-ostree"
|
||||
echo "The installation will be performed using rpm-ostree package manager"
|
||||
elif [ -x "$(command -v bootc)" ]; then
|
||||
echo "Detected bootc system without rpm-ostree." >&2
|
||||
echo "NetBird cannot be installed via package manager on this system." >&2
|
||||
echo "Options:" >&2
|
||||
echo " 1. Install via Distrobox (instructions in the installation docs)" >&2
|
||||
echo " 2. Rebuild your base image with rpm-ostree included" >&2
|
||||
echo " 3. Bake NetBird into your Containerfile" >&2
|
||||
exit 1
|
||||
else
|
||||
echo "Detected ostree-booted system without rpm-ostree or bootc." >&2
|
||||
echo "NetBird cannot be installed automatically on this atomic system." >&2
|
||||
echo "Please install NetBird by rebuilding your base image or use a supported package manager." >&2
|
||||
exit 1
|
||||
fi
|
||||
elif [ -x "$(command -v apt-get)" ]; then
|
||||
PACKAGE_MANAGER="apt"
|
||||
echo "The installation will be performed using apt package manager"
|
||||
elif [ -x "$(command -v dnf)" ]; then
|
||||
PACKAGE_MANAGER="dnf"
|
||||
echo "The installation will be performed using dnf package manager"
|
||||
elif [ -x "$(command -v rpm-ostree)" ]; then
|
||||
PACKAGE_MANAGER="rpm-ostree"
|
||||
echo "The installation will be performed using rpm-ostree package manager"
|
||||
elif [ -x "$(command -v yum)" ]; then
|
||||
PACKAGE_MANAGER="yum"
|
||||
echo "The installation will be performed using yum package manager"
|
||||
|
||||
@@ -6,4 +6,5 @@ const (
|
||||
RoleKey = "role"
|
||||
UserIDKey = "userID"
|
||||
PeerIDKey = "peerID"
|
||||
UserAgentKey = "userAgent"
|
||||
)
|
||||
|
||||
@@ -5107,31 +5107,63 @@ components:
|
||||
responses:
|
||||
not_found:
|
||||
description: Resource not found
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
validation_failed_simple:
|
||||
description: Validation failed
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
bad_request:
|
||||
description: Bad Request
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
internal_error:
|
||||
description: Internal Server Error
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
validation_failed:
|
||||
description: Validation failed
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
forbidden:
|
||||
description: Forbidden
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
requires_authentication:
|
||||
description: Requires authentication
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
conflict:
|
||||
description: Conflict
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
headers:
|
||||
X-Request-Id:
|
||||
description: |
|
||||
Unique identifier assigned to the request by the server and set on every
|
||||
response. Useful for correlating client requests with server-side logs.
|
||||
schema:
|
||||
type: string
|
||||
example: cot7r4n3l3vh3qj4qveg
|
||||
securitySchemes:
|
||||
BearerAuth:
|
||||
type: http
|
||||
|
||||
@@ -9,12 +9,14 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer"
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
"github.com/netbirdio/netbird/shared/relay/healthcheck"
|
||||
"github.com/netbirdio/netbird/shared/relay/messages"
|
||||
)
|
||||
@@ -172,6 +174,19 @@ type Client struct {
|
||||
stateSubscription *PeersStateSubscription
|
||||
|
||||
mtu uint16
|
||||
|
||||
// transportFallback, when set, records datagram-too-large failures so a
|
||||
// datagram-sized transport is avoided on subsequent connects. Shared via
|
||||
// the manager.
|
||||
transportFallback *transportFallback
|
||||
// datagramFallbackTriggered guards a single fallback per connection so a
|
||||
// burst of oversized datagrams triggers one reconnect, not many.
|
||||
datagramFallbackTriggered atomic.Bool
|
||||
}
|
||||
|
||||
// SetTransportFallback wires the shared datagram-transport fallback tracker.
|
||||
func (c *Client) SetTransportFallback(tf *transportFallback) {
|
||||
c.transportFallback = tf
|
||||
}
|
||||
|
||||
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
||||
@@ -361,12 +376,13 @@ func (c *Client) Close() error {
|
||||
}
|
||||
|
||||
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
dialers := c.getDialers()
|
||||
mode := transportModeFromEnv()
|
||||
dialers := c.getDialers(mode)
|
||||
|
||||
var conn net.Conn
|
||||
if c.serverIP.IsValid() {
|
||||
var err error
|
||||
conn, err = c.dialRaceDirect(ctx, dialers)
|
||||
conn, err = c.dialRaceDirect(ctx, mode, dialers)
|
||||
if err != nil {
|
||||
c.log.Infof("dial via server IP %s failed, falling back to FQDN: %v", c.serverIP, err)
|
||||
conn = nil
|
||||
@@ -375,6 +391,9 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
|
||||
if conn == nil {
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
|
||||
if mode.sequential() {
|
||||
rd.WithSequential()
|
||||
}
|
||||
var err error
|
||||
conn, err = rd.Dial(ctx)
|
||||
if err != nil {
|
||||
@@ -382,6 +401,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
}
|
||||
}
|
||||
c.relayConn = conn
|
||||
c.datagramFallbackTriggered.Store(false)
|
||||
|
||||
instanceURL, err := c.handShake(ctx)
|
||||
if err != nil {
|
||||
@@ -396,7 +416,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
}
|
||||
|
||||
// dialRaceDirect dials c.serverIP, preserving the original FQDN as the TLS ServerName for SNI.
|
||||
func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (net.Conn, error) {
|
||||
func (c *Client) dialRaceDirect(ctx context.Context, mode TransportMode, dialers []dialer.DialeFn) (net.Conn, error) {
|
||||
directURL, serverName, err := substituteHost(c.connectionURL, c.serverIP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("substitute host: %w", err)
|
||||
@@ -406,6 +426,9 @@ func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (
|
||||
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, directURL, dialers...).
|
||||
WithServerName(serverName)
|
||||
if mode.sequential() {
|
||||
rd.WithSequential()
|
||||
}
|
||||
return rd.Dial(ctx)
|
||||
}
|
||||
|
||||
@@ -631,13 +654,53 @@ func (c *Client) writeTo(containerRef *connContainer, dstID messages.PeerID, pay
|
||||
}
|
||||
|
||||
// the write always return with 0 length because the underling does not support the size feedback.
|
||||
_, err = c.relayConn.Write(msg)
|
||||
conn := c.relayConn
|
||||
_, err = conn.Write(msg)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to write transport message: %s", err)
|
||||
if errors.Is(err, netErr.ErrDatagramTooLarge) {
|
||||
c.onDatagramTooLarge(conn, err)
|
||||
} else {
|
||||
c.log.Errorf("failed to write transport message: %s", err)
|
||||
}
|
||||
}
|
||||
return len(payload), err
|
||||
}
|
||||
|
||||
// onDatagramTooLarge reacts to a datagram rejected as too large for the path.
|
||||
// When a non-datagram transport is available, it records a fallback for this
|
||||
// server and closes the connection so the reconnect avoids datagram-sized
|
||||
// transports. A single fallback is triggered per connection regardless of how
|
||||
// many oversized datagrams arrive. cause carries the datagram size and budget.
|
||||
func (c *Client) onDatagramTooLarge(conn net.Conn, cause error) {
|
||||
// Handle one oversized datagram per connection; a burst triggers a single
|
||||
// fallback (and a single log line), not many.
|
||||
if !c.datagramFallbackTriggered.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
|
||||
// If the selected mode offers no non-datagram transport (e.g. pinned to a
|
||||
// datagram-sized transport), reconnecting would just re-fail, so leave the
|
||||
// connection up rather than loop.
|
||||
if len(nonDatagramSized(c.baseDialers(transportModeFromEnv()))) == 0 {
|
||||
c.log.Warnf("%s, but no non-datagram transport is available, not falling back", cause)
|
||||
return
|
||||
}
|
||||
|
||||
// Without the shared tracker a reconnect would just select the same
|
||||
// transport again and re-fail, so leave the connection up rather than loop.
|
||||
if c.transportFallback == nil {
|
||||
c.log.Debugf("%s, but no transport fallback configured, leaving connection up", cause)
|
||||
return
|
||||
}
|
||||
|
||||
window := c.transportFallback.recordFailure(c.connectionURL)
|
||||
c.log.Warnf("%s, avoiding datagram-sized transport for %s", cause, window)
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
c.log.Debugf("close relay connection for transport fallback: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) listenForStopEvents(ctx context.Context, hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
|
||||
for {
|
||||
select {
|
||||
|
||||
18
shared/relay/client/dialer/capability.go
Normal file
18
shared/relay/client/dialer/capability.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package dialer
|
||||
|
||||
// DatagramSized is implemented by dialers whose connections carry each write in
|
||||
// a single datagram, so a write can be rejected when it exceeds the path's
|
||||
// datagram budget (e.g. QUIC). Transports without this capability (e.g.
|
||||
// WebSocket over TCP) impose no per-write size limit, so the relay client can
|
||||
// fall back to them when a datagram-sized transport rejects a write as too
|
||||
// large. The capability is advertised per dialer rather than hardcoded, so a
|
||||
// new transport only needs to declare whether it is datagram-sized.
|
||||
type DatagramSized interface {
|
||||
DatagramSized()
|
||||
}
|
||||
|
||||
// IsDatagramSized reports whether d produces datagram-sized connections.
|
||||
func IsDatagramSized(d DialeFn) bool {
|
||||
_, ok := d.(DatagramSized)
|
||||
return ok
|
||||
}
|
||||
@@ -4,4 +4,9 @@ import "errors"
|
||||
|
||||
var (
|
||||
ErrClosedByServer = errors.New("closed by server")
|
||||
|
||||
// ErrDatagramTooLarge is returned when a transport message exceeds the
|
||||
// QUIC datagram size the path to the relay can carry. The relay client
|
||||
// treats it as a signal to fall back to a non-datagram transport.
|
||||
ErrDatagramTooLarge = errors.New("datagram frame too large")
|
||||
)
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
)
|
||||
@@ -52,11 +51,8 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
err := c.session.SendDatagram(b)
|
||||
if err != nil {
|
||||
err = c.remoteCloseErrHandling(err)
|
||||
log.Errorf("failed to write to QUIC stream: %v", err)
|
||||
return 0, err
|
||||
if err := c.session.SendDatagram(b); err != nil {
|
||||
return 0, c.writeErrHandling(err, len(b))
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
@@ -95,3 +91,15 @@ func (c *Conn) remoteCloseErrHandling(err error) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// writeErrHandling normalizes SendDatagram errors. A datagram that exceeds the
|
||||
// path's QUIC packet budget is mapped to ErrDatagramTooLarge (annotated with the
|
||||
// datagram size and path budget) so the relay client can fall back to a
|
||||
// non-datagram transport.
|
||||
func (c *Conn) writeErrHandling(err error, size int) error {
|
||||
var tooLarge *quic.DatagramTooLargeError
|
||||
if errors.As(err, &tooLarge) {
|
||||
return fmt.Errorf("%w: %d byte datagram over path budget %d", netErr.ErrDatagramTooLarge, size, tooLarge.MaxDatagramPayloadSize)
|
||||
}
|
||||
return c.remoteCloseErrHandling(err)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
@@ -23,6 +24,12 @@ func (d Dialer) Protocol() string {
|
||||
return Network
|
||||
}
|
||||
|
||||
// DatagramSized marks QUIC as a datagram-sized transport: relay traffic is
|
||||
// carried in QUIC DATAGRAM frames, which must fit a single packet.
|
||||
func (d Dialer) DatagramSized() {
|
||||
// Intentional marker method; presence is the capability signal.
|
||||
}
|
||||
|
||||
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
|
||||
quicURL, err := prepareURL(address)
|
||||
if err != nil {
|
||||
@@ -47,6 +54,7 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
|
||||
MaxIdleTimeout: 4 * time.Minute,
|
||||
EnableDatagrams: true,
|
||||
InitialPacketSize: nbRelay.QUICInitialPacketSize,
|
||||
Tracer: connectionTracer(quicURL),
|
||||
}
|
||||
|
||||
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{Port: 0})
|
||||
@@ -74,6 +82,28 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// connectionTracer returns a QUIC tracer that logs the DPLPMTUD result and the
|
||||
// reason a relay connection closed, so the path MTU settled on and teardown
|
||||
// cause are visible in logs. Lines carry the relay address as a structured
|
||||
// field, matching the rest of the relay client logging.
|
||||
func connectionTracer(addr string) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
relayLog := log.WithField("relay", addr)
|
||||
return func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return &logging.ConnectionTracer{
|
||||
UpdatedMTU: func(mtu logging.ByteCount, done bool) {
|
||||
if done {
|
||||
relayLog.Infof("QUIC path MTU settled at %d", mtu)
|
||||
return
|
||||
}
|
||||
relayLog.Debugf("QUIC path MTU probing at %d", mtu)
|
||||
},
|
||||
ClosedConnection: func(err error) {
|
||||
relayLog.Debugf("QUIC connection closed: %v", err)
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func prepareURL(address string) (string, error) {
|
||||
var host string
|
||||
var defaultPort string
|
||||
|
||||
@@ -32,6 +32,7 @@ type RaceDial struct {
|
||||
serverName string
|
||||
dialerFns []DialeFn
|
||||
connectionTimeout time.Duration
|
||||
sequential bool
|
||||
}
|
||||
|
||||
func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
|
||||
@@ -53,7 +54,21 @@ func (r *RaceDial) WithServerName(serverName string) *RaceDial {
|
||||
return r
|
||||
}
|
||||
|
||||
// WithSequential makes Dial try the dialers in order, falling back to the next
|
||||
// only when one fails to connect, instead of racing them concurrently.
|
||||
//
|
||||
// Mutates the receiver and is not safe for concurrent reconfiguration; a
|
||||
// RaceDial is intended to be constructed per dial and discarded.
|
||||
func (r *RaceDial) WithSequential() *RaceDial {
|
||||
r.sequential = true
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
|
||||
if r.sequential {
|
||||
return r.dialSequential(ctx)
|
||||
}
|
||||
|
||||
connChan := make(chan dialResult, len(r.dialerFns))
|
||||
winnerConn := make(chan net.Conn, 1)
|
||||
abortCtx, abort := context.WithCancel(ctx)
|
||||
@@ -72,6 +87,30 @@ func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// dialSequential tries each dialer in order, returning the first connection and
|
||||
// falling back to the next on failure.
|
||||
func (r *RaceDial) dialSequential(ctx context.Context) (net.Conn, error) {
|
||||
for _, dfn := range r.dialerFns {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attemptCtx, cancel := context.WithTimeout(ctx, r.connectionTimeout)
|
||||
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
|
||||
conn, err := dfn.Dial(attemptCtx, r.serverURL, r.serverName)
|
||||
cancel()
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, err
|
||||
}
|
||||
r.log.Errorf("failed to dial via %s: %s", dfn.Protocol(), err)
|
||||
continue
|
||||
}
|
||||
r.log.Infof("successfully dialed via: %s", dfn.Protocol())
|
||||
return conn, nil
|
||||
}
|
||||
return nil, errors.New("failed to dial to Relay server on any protocol")
|
||||
}
|
||||
|
||||
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
|
||||
ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -250,3 +250,66 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialSequentialFallback(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
var firstDialed, secondDialed bool
|
||||
preferred := &MockDialer{
|
||||
protocolStr: "quic",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
firstDialed = true
|
||||
return nil, errors.New("quic unreachable")
|
||||
},
|
||||
}
|
||||
fallbackConn := &MockConn{remoteAddr: &MockAddr{network: "ws"}}
|
||||
fallback := &MockDialer{
|
||||
protocolStr: "ws",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
secondDialed = true
|
||||
return fallbackConn, nil
|
||||
},
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, preferred, fallback).WithSequential()
|
||||
conn, err := rd.Dial(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("expected fallback to succeed, got %v", err)
|
||||
}
|
||||
if conn != fallbackConn {
|
||||
t.Errorf("expected fallback connection, got %v", conn)
|
||||
}
|
||||
if !firstDialed || !secondDialed {
|
||||
t.Errorf("expected both dialers attempted in order, first=%v second=%v", firstDialed, secondDialed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialSequentialPreferredWins(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
preferredConn := &MockConn{remoteAddr: &MockAddr{network: "quic"}}
|
||||
preferred := &MockDialer{
|
||||
protocolStr: "quic",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
return preferredConn, nil
|
||||
},
|
||||
}
|
||||
fallback := &MockDialer{
|
||||
protocolStr: "ws",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
t.Errorf("fallback dialer must not be tried when preferred succeeds")
|
||||
return nil, errors.New("should not happen")
|
||||
},
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, preferred, fallback).WithSequential()
|
||||
conn, err := rd.Dial(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("expected preferred to succeed, got %v", err)
|
||||
}
|
||||
if conn != preferredConn {
|
||||
t.Errorf("expected preferred connection, got %v", conn)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,11 +9,42 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
|
||||
)
|
||||
|
||||
// getDialers returns the list of dialers to use for connecting to the relay server.
|
||||
func (c *Client) getDialers() []dialer.DialeFn {
|
||||
if c.mtu > 0 && c.mtu > iface.DefaultMTU {
|
||||
c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU)
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
// getDialers returns the ordered dialers for connecting to the relay server. It
|
||||
// applies the datagram fallback generically: if this server recently rejected a
|
||||
// datagram-sized transport, those dialers are dropped, leaving the rest.
|
||||
func (c *Client) getDialers(mode TransportMode) []dialer.DialeFn {
|
||||
dialers := c.baseDialers(mode)
|
||||
|
||||
if c.transportFallback != nil && c.transportFallback.avoidDatagramSized(c.connectionURL) {
|
||||
if filtered := nonDatagramSized(dialers); len(filtered) > 0 {
|
||||
c.log.Infof("relay recently rejected a datagram-sized transport, avoiding it")
|
||||
return filtered
|
||||
}
|
||||
}
|
||||
return []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
|
||||
return dialers
|
||||
}
|
||||
|
||||
// baseDialers returns the ordered dialers for the mode, before any datagram
|
||||
// fallback filtering. For racing modes (auto) the order is irrelevant; for
|
||||
// prefer modes the first entry is tried before falling back to the second.
|
||||
func (c *Client) baseDialers(mode TransportMode) []dialer.DialeFn {
|
||||
switch mode {
|
||||
case TransportModeWS:
|
||||
c.log.Infof("%s=ws, using WebSocket transport", EnvRelayTransport)
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
case TransportModeQUIC:
|
||||
c.log.Infof("%s=quic, using QUIC transport", EnvRelayTransport)
|
||||
return []dialer.DialeFn{quic.Dialer{}}
|
||||
}
|
||||
|
||||
all := []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
|
||||
if mode == TransportModePreferWS {
|
||||
all = []dialer.DialeFn{ws.Dialer{}, quic.Dialer{}}
|
||||
}
|
||||
|
||||
if c.mtu > 0 && c.mtu > iface.DefaultMTU {
|
||||
c.log.Infof("MTU %d exceeds default (%d), avoiding datagram-sized transports", c.mtu, iface.DefaultMTU)
|
||||
return nonDatagramSized(all)
|
||||
}
|
||||
return all
|
||||
}
|
||||
|
||||
101
shared/relay/client/dialers_generic_test.go
Normal file
101
shared/relay/client/dialers_generic_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
//go:build !js
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer"
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/quic"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
|
||||
)
|
||||
|
||||
// TestDatagramSizedCapability locks the capability the generic fallback relies
|
||||
// on: QUIC is datagram-sized, WebSocket is not.
|
||||
func TestDatagramSizedCapability(t *testing.T) {
|
||||
assert.True(t, dialer.IsDatagramSized(quic.Dialer{}), "QUIC must advertise datagram-sized")
|
||||
assert.False(t, dialer.IsDatagramSized(ws.Dialer{}), "WebSocket must not advertise datagram-sized")
|
||||
}
|
||||
|
||||
func protocols(dialers []dialer.DialeFn) []string {
|
||||
out := make([]string, len(dialers))
|
||||
for i, d := range dialers {
|
||||
out[i] = d.Protocol()
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestGetDialers(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mode string
|
||||
mtu uint16
|
||||
preferWS bool
|
||||
want []string
|
||||
}{
|
||||
{name: "auto races quic and ws", mode: "auto", mtu: iface.DefaultMTU, want: []string{"quic", "WS"}},
|
||||
{name: "ws pinned", mode: "ws", mtu: iface.DefaultMTU, want: []string{"WS"}},
|
||||
{name: "quic pinned", mode: "quic", mtu: iface.DefaultMTU, want: []string{"quic"}},
|
||||
{name: "prefer-quic orders quic first", mode: "prefer-quic", mtu: iface.DefaultMTU, want: []string{"quic", "WS"}},
|
||||
{name: "prefer-ws orders ws first", mode: "prefer-ws", mtu: iface.DefaultMTU, want: []string{"WS", "quic"}},
|
||||
{name: "mtu above default forces ws", mode: "auto", mtu: iface.DefaultMTU + 100, want: []string{"WS"}},
|
||||
{name: "sticky fallback forces ws in auto", mode: "auto", mtu: iface.DefaultMTU, preferWS: true, want: []string{"WS"}},
|
||||
{name: "sticky fallback forces ws in prefer-quic", mode: "prefer-quic", mtu: iface.DefaultMTU, preferWS: true, want: []string{"WS"}},
|
||||
{name: "quic pin overrides sticky fallback", mode: "quic", mtu: iface.DefaultMTU, preferWS: true, want: []string{"quic"}},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(EnvRelayTransport, tc.mode)
|
||||
if tc.mode == "" {
|
||||
os.Unsetenv(EnvRelayTransport)
|
||||
}
|
||||
|
||||
tf := newTransportFallback()
|
||||
if tc.preferWS {
|
||||
tf.recordFailure(url)
|
||||
}
|
||||
|
||||
c := &Client{
|
||||
log: log.WithField("test", t.Name()),
|
||||
connectionURL: url,
|
||||
mtu: tc.mtu,
|
||||
transportFallback: tf,
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.want, protocols(c.getDialers(transportModeFromEnv())))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStickyFallbackAfterDatagramTooLarge verifies the full chain: an oversized
|
||||
// datagram records a fallback that makes the next dial pick WebSocket, the way a
|
||||
// reconnect would after the connection is closed.
|
||||
func TestStickyFallbackAfterDatagramTooLarge(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
t.Setenv(EnvRelayTransport, string(TransportModeAuto))
|
||||
|
||||
c := &Client{
|
||||
log: log.WithField("test", t.Name()),
|
||||
connectionURL: url,
|
||||
mtu: iface.DefaultMTU,
|
||||
transportFallback: newTransportFallback(),
|
||||
}
|
||||
|
||||
// First dial races both transports.
|
||||
assert.Equal(t, []string{"quic", "WS"}, protocols(c.getDialers(transportModeFromEnv())))
|
||||
|
||||
// An oversized datagram records the fallback for this server.
|
||||
c.onDatagramTooLarge(&closeTrackingConn{}, netErr.ErrDatagramTooLarge)
|
||||
|
||||
// The reconnect now sticks to WebSocket.
|
||||
assert.Equal(t, []string{"WS"}, protocols(c.getDialers(transportModeFromEnv())))
|
||||
}
|
||||
@@ -7,7 +7,11 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
|
||||
)
|
||||
|
||||
func (c *Client) getDialers() []dialer.DialeFn {
|
||||
func (c *Client) getDialers(_ TransportMode) []dialer.DialeFn {
|
||||
// JS/WASM build only uses WebSocket transport
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
}
|
||||
|
||||
func (c *Client) baseDialers(_ TransportMode) []dialer.DialeFn {
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
}
|
||||
|
||||
@@ -79,23 +79,30 @@ type Manager struct {
|
||||
|
||||
cleanupInterval time.Duration
|
||||
keepUnusedServerTime time.Duration
|
||||
|
||||
// transportFallback is shared across home and foreign relay clients so a
|
||||
// datagram-too-large failure makes that server avoid datagram-sized transports across reconnects.
|
||||
transportFallback *transportFallback
|
||||
}
|
||||
|
||||
// NewManager creates a new manager instance.
|
||||
// The serverURL address can be empty. In this case, the manager will not serve.
|
||||
func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16, opts ...ManagerOption) *Manager {
|
||||
tokenStore := &relayAuth.TokenStore{}
|
||||
tf := newTransportFallback()
|
||||
|
||||
m := &Manager{
|
||||
ctx: ctx,
|
||||
peerID: peerID,
|
||||
tokenStore: tokenStore,
|
||||
mtu: mtu,
|
||||
ctx: ctx,
|
||||
peerID: peerID,
|
||||
tokenStore: tokenStore,
|
||||
mtu: mtu,
|
||||
transportFallback: tf,
|
||||
serverPicker: &ServerPicker{
|
||||
TokenStore: tokenStore,
|
||||
PeerID: peerID,
|
||||
MTU: mtu,
|
||||
ConnectionTimeout: defaultConnectionTimeout,
|
||||
TransportFallback: tf,
|
||||
},
|
||||
relayClients: make(map[string]*RelayTrack),
|
||||
onDisconnectedListeners: make(map[string]*list.List),
|
||||
@@ -287,6 +294,7 @@ func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string
|
||||
m.relayClientsMutex.Unlock()
|
||||
|
||||
relayClient := NewClientWithServerIP(serverAddress, serverIP, m.tokenStore, m.peerID, m.mtu)
|
||||
relayClient.SetTransportFallback(m.transportFallback)
|
||||
err := relayClient.Connect(m.ctx)
|
||||
if err != nil {
|
||||
rt.err = err
|
||||
|
||||
@@ -29,6 +29,7 @@ type ServerPicker struct {
|
||||
PeerID string
|
||||
MTU uint16
|
||||
ConnectionTimeout time.Duration
|
||||
TransportFallback *transportFallback
|
||||
}
|
||||
|
||||
func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
|
||||
@@ -70,6 +71,7 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
|
||||
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
|
||||
log.Infof("try to connecting to relay server: %s", url)
|
||||
relayClient := NewClient(url, sp.TokenStore, sp.PeerID, sp.MTU)
|
||||
relayClient.SetTransportFallback(sp.TransportFallback)
|
||||
err := relayClient.Connect(ctx)
|
||||
resultChan <- connResult{
|
||||
RelayClient: relayClient,
|
||||
|
||||
129
shared/relay/client/transport.go
Normal file
129
shared/relay/client/transport.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer"
|
||||
)
|
||||
|
||||
// EnvRelayTransport pins the relay transport. Valid values: "auto" (default,
|
||||
// race QUIC and WebSocket), "quic" (QUIC only), "ws" (WebSocket only),
|
||||
// "prefer-quic" / "prefer-ws" (try the preferred transport first, fall back to
|
||||
// the other only if it fails to connect; no race). The prefer modes trade a
|
||||
// slower connect when the preferred transport is blackholed for deterministic
|
||||
// transport selection.
|
||||
const EnvRelayTransport = "NB_RELAY_TRANSPORT"
|
||||
|
||||
const (
|
||||
// transportFallbackBase is the initial window a relay server avoids
|
||||
// datagram-sized transports after a datagram is rejected as too large.
|
||||
transportFallbackBase = 10 * time.Minute
|
||||
// transportFallbackMax caps the pinned window when failures repeat.
|
||||
transportFallbackMax = 60 * time.Minute
|
||||
)
|
||||
|
||||
// TransportMode selects which relay dialers are used.
|
||||
type TransportMode string
|
||||
|
||||
const (
|
||||
TransportModeAuto TransportMode = "auto"
|
||||
TransportModeQUIC TransportMode = "quic"
|
||||
TransportModeWS TransportMode = "ws"
|
||||
TransportModePreferQUIC TransportMode = "prefer-quic"
|
||||
TransportModePreferWS TransportMode = "prefer-ws"
|
||||
)
|
||||
|
||||
// transportModeFromEnv reads EnvRelayTransport, defaulting to auto for an empty
|
||||
// or unrecognized value.
|
||||
func transportModeFromEnv() TransportMode {
|
||||
switch TransportMode(strings.ToLower(strings.TrimSpace(os.Getenv(EnvRelayTransport)))) {
|
||||
case "", TransportModeAuto:
|
||||
return TransportModeAuto
|
||||
case TransportModeQUIC:
|
||||
return TransportModeQUIC
|
||||
case TransportModeWS:
|
||||
return TransportModeWS
|
||||
case TransportModePreferQUIC:
|
||||
return TransportModePreferQUIC
|
||||
case TransportModePreferWS:
|
||||
return TransportModePreferWS
|
||||
default:
|
||||
log.Warnf("invalid %s value %q, using %q", EnvRelayTransport, os.Getenv(EnvRelayTransport), TransportModeAuto)
|
||||
return TransportModeAuto
|
||||
}
|
||||
}
|
||||
|
||||
// sequential reports whether the mode tries dialers in order with fallback
|
||||
// instead of racing them concurrently.
|
||||
func (m TransportMode) sequential() bool {
|
||||
return m == TransportModePreferQUIC || m == TransportModePreferWS
|
||||
}
|
||||
|
||||
// transportFallback tracks relay servers that have rejected a datagram-sized
|
||||
// transport (a write too large for the path) and should temporarily avoid such
|
||||
// transports. It is shared across the relay manager so the preference survives
|
||||
// client recreation (foreign relay clients are evicted and rebuilt on
|
||||
// disconnect). Entries are keyed by server URL and expire after a window that
|
||||
// grows on repeated failures.
|
||||
type transportFallback struct {
|
||||
mu sync.Mutex
|
||||
entries map[string]*fallbackEntry
|
||||
}
|
||||
|
||||
type fallbackEntry struct {
|
||||
until time.Time
|
||||
duration time.Duration
|
||||
}
|
||||
|
||||
func newTransportFallback() *transportFallback {
|
||||
return &transportFallback{entries: make(map[string]*fallbackEntry)}
|
||||
}
|
||||
|
||||
// avoidDatagramSized reports whether serverURL is currently within a window
|
||||
// where datagram-sized transports should be avoided.
|
||||
func (f *transportFallback) avoidDatagramSized(serverURL string) bool {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
e := f.entries[serverURL]
|
||||
return e != nil && time.Now().Before(e.until)
|
||||
}
|
||||
|
||||
// recordFailure makes serverURL avoid datagram-sized transports for a window:
|
||||
// transportFallbackBase on the first failure, doubling up to transportFallbackMax
|
||||
// when a datagram transport fails again after a previous window expired. It
|
||||
// returns the active window duration.
|
||||
func (f *transportFallback) recordFailure(serverURL string) time.Duration {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
e := f.entries[serverURL]
|
||||
switch {
|
||||
case e == nil:
|
||||
e = &fallbackEntry{duration: transportFallbackBase}
|
||||
f.entries[serverURL] = e
|
||||
case now.Before(e.until):
|
||||
return time.Until(e.until)
|
||||
default:
|
||||
e.duration = min(e.duration*2, transportFallbackMax)
|
||||
}
|
||||
e.until = now.Add(e.duration)
|
||||
return e.duration
|
||||
}
|
||||
|
||||
// nonDatagramSized returns the dialers from in that are not datagram-sized,
|
||||
// preserving order.
|
||||
func nonDatagramSized(in []dialer.DialeFn) []dialer.DialeFn {
|
||||
out := make([]dialer.DialeFn, 0, len(in))
|
||||
for _, d := range in {
|
||||
if !dialer.IsDatagramSized(d) {
|
||||
out = append(out, d)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
140
shared/relay/client/transport_test.go
Normal file
140
shared/relay/client/transport_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
)
|
||||
|
||||
// closeTrackingConn records whether Close was called; only Close is exercised.
|
||||
type closeTrackingConn struct {
|
||||
net.Conn
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (c *closeTrackingConn) Close() error {
|
||||
c.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestTransportModeFromEnv(t *testing.T) {
|
||||
tests := []struct {
|
||||
value string
|
||||
want TransportMode
|
||||
}{
|
||||
{"", TransportModeAuto},
|
||||
{"auto", TransportModeAuto},
|
||||
{"quic", TransportModeQUIC},
|
||||
{"QUIC", TransportModeQUIC},
|
||||
{"ws", TransportModeWS},
|
||||
{" Ws ", TransportModeWS},
|
||||
{"prefer-quic", TransportModePreferQUIC},
|
||||
{"prefer-ws", TransportModePreferWS},
|
||||
{"garbage", TransportModeAuto},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.value, func(t *testing.T) {
|
||||
t.Setenv(EnvRelayTransport, tc.value)
|
||||
if tc.value == "" {
|
||||
os.Unsetenv(EnvRelayTransport)
|
||||
}
|
||||
assert.Equal(t, tc.want, transportModeFromEnv())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportFallbackRecordAndExpiry(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
f := newTransportFallback()
|
||||
|
||||
assert.False(t, f.avoidDatagramSized(url), "no fallback recorded yet")
|
||||
|
||||
d := f.recordFailure(url)
|
||||
assert.Equal(t, transportFallbackBase, d, "first failure pins for the base window")
|
||||
assert.True(t, f.avoidDatagramSized(url), "datagram-sized transport avoided within the window")
|
||||
|
||||
// A second failure while still inside the window must not grow the window.
|
||||
d = f.recordFailure(url)
|
||||
assert.LessOrEqual(t, d, transportFallbackBase, "still within the active window")
|
||||
require.NotNil(t, f.entries[url])
|
||||
assert.Equal(t, transportFallbackBase, f.entries[url].duration, "duration unchanged inside window")
|
||||
|
||||
// Expire the window: datagram-sized transport allowed again.
|
||||
f.entries[url].until = time.Now().Add(-time.Second)
|
||||
assert.False(t, f.avoidDatagramSized(url), "window expired, datagram-sized transport allowed")
|
||||
}
|
||||
|
||||
func TestTransportFallbackGrowsOnRepeat(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
f := newTransportFallback()
|
||||
|
||||
want := transportFallbackBase
|
||||
for i := range 6 {
|
||||
d := f.recordFailure(url)
|
||||
assert.Equal(t, want, d, "window after %d expiries", i)
|
||||
|
||||
// expire the window so the next failure is treated as a repeat
|
||||
f.entries[url].until = time.Now().Add(-time.Second)
|
||||
|
||||
want = min(want*2, transportFallbackMax)
|
||||
}
|
||||
|
||||
assert.Equal(t, transportFallbackMax, f.entries[url].duration, "window caps at the max")
|
||||
}
|
||||
|
||||
func TestOnDatagramTooLargeAuto(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
t.Setenv(EnvRelayTransport, string(TransportModeAuto))
|
||||
|
||||
tf := newTransportFallback()
|
||||
c := &Client{
|
||||
log: log.WithField("test", t.Name()),
|
||||
connectionURL: url,
|
||||
transportFallback: tf,
|
||||
}
|
||||
conn := &closeTrackingConn{}
|
||||
|
||||
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
|
||||
|
||||
assert.True(t, conn.closed, "connection closed to force reconnect")
|
||||
assert.True(t, tf.avoidDatagramSized(url), "fallback recorded for the server")
|
||||
|
||||
// A second oversized datagram on the same connection must not re-close.
|
||||
conn.closed = false
|
||||
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
|
||||
assert.False(t, conn.closed, "single fallback per connection")
|
||||
}
|
||||
|
||||
func TestOnDatagramTooLargeQUICPinned(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
t.Setenv(EnvRelayTransport, string(TransportModeQUIC))
|
||||
|
||||
tf := newTransportFallback()
|
||||
c := &Client{
|
||||
log: log.WithField("test", t.Name()),
|
||||
connectionURL: url,
|
||||
transportFallback: tf,
|
||||
}
|
||||
conn := &closeTrackingConn{}
|
||||
|
||||
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
|
||||
|
||||
assert.False(t, conn.closed, "QUIC pin keeps the connection, no fallback redial")
|
||||
assert.False(t, tf.avoidDatagramSized(url), "QUIC pin records no fallback")
|
||||
}
|
||||
|
||||
func TestTransportFallbackPerServer(t *testing.T) {
|
||||
f := newTransportFallback()
|
||||
f.recordFailure("rels://a.example:443")
|
||||
|
||||
assert.True(t, f.avoidDatagramSized("rels://a.example:443"))
|
||||
assert.False(t, f.avoidDatagramSized("rels://b.example:443"), "fallback is scoped to one server")
|
||||
}
|
||||
Reference in New Issue
Block a user