mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-10 09:59:55 +00:00
Compare commits
17 Commits
profile-id
...
windows-dn
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9157b74945 | ||
|
|
a40028092d | ||
|
|
13200265d8 | ||
|
|
ed7a9363aa | ||
|
|
d56859dc5d | ||
|
|
367d37050b | ||
|
|
106527182f | ||
|
|
8e1d5b78c2 | ||
|
|
d3b63c6be9 | ||
|
|
cca46f070b | ||
|
|
5f8b88471f | ||
|
|
f42b8aed90 | ||
|
|
0415137acd | ||
|
|
7fd16666e3 | ||
|
|
0571eeaba0 | ||
|
|
6a201d12b5 | ||
|
|
4810e79a00 |
@@ -92,6 +92,9 @@ linters:
|
||||
- linters:
|
||||
- unused
|
||||
path: client/firewall/iptables/rule\.go
|
||||
- linters:
|
||||
- unused
|
||||
path: client/internal/dns/dnsfw/(types|syscall|zsyscall)_windows.*\.go
|
||||
- linters:
|
||||
- gosec
|
||||
- mirror
|
||||
|
||||
@@ -3,6 +3,7 @@ package iptables
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net"
|
||||
"slices"
|
||||
|
||||
@@ -421,12 +422,17 @@ func (m *aclManager) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
// Clone the maps so the persisted state holds a private snapshot. The
|
||||
// live maps keep being mutated by subsequent rule operations while the
|
||||
// state manager marshals the state from its periodic-save goroutine.
|
||||
// Sharing them by reference races the two and aborts the process with a
|
||||
// concurrent map iteration and write.
|
||||
if m.v6 {
|
||||
currentState.ACLEntries6 = m.entries
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore
|
||||
currentState.ACLEntries6 = maps.Clone(m.entries)
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore.clone()
|
||||
} else {
|
||||
currentState.ACLEntries = m.entries
|
||||
currentState.ACLIPsetStore = m.ipsetStore
|
||||
currentState.ACLEntries = maps.Clone(m.entries)
|
||||
currentState.ACLIPsetStore = m.ipsetStore.clone()
|
||||
}
|
||||
|
||||
if err := m.stateManager.UpdateState(currentState); err != nil {
|
||||
|
||||
@@ -4,6 +4,7 @@ package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -749,11 +750,17 @@ func (r *router) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
// Clone the rule map so the persisted state holds a private snapshot. The
|
||||
// live map keeps being mutated by subsequent rule operations while the
|
||||
// state manager marshals the state from its periodic-save goroutine.
|
||||
// Sharing it by reference races the two and aborts the process with a
|
||||
// concurrent map iteration and write. The ipset counter guards itself
|
||||
// during marshaling, so it can be shared directly.
|
||||
if r.v6 {
|
||||
currentState.RouteRules6 = r.rules
|
||||
currentState.RouteRules6 = maps.Clone(r.rules)
|
||||
currentState.RouteIPsetCounter6 = r.ipsetCounter
|
||||
} else {
|
||||
currentState.RouteRules = r.rules
|
||||
currentState.RouteRules = maps.Clone(r.rules)
|
||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package iptables
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
"maps"
|
||||
)
|
||||
|
||||
type ipList struct {
|
||||
ips map[string]struct{}
|
||||
@@ -19,6 +22,14 @@ func (s *ipList) addIP(ip string) {
|
||||
s.ips[ip] = struct{}{}
|
||||
}
|
||||
|
||||
// clone returns a deep copy of the ipList with its own ips map.
|
||||
func (s *ipList) clone() *ipList {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return &ipList{ips: maps.Clone(s.ips)}
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
func (s *ipList) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
@@ -55,6 +66,19 @@ func newIpsetStore() *ipsetStore {
|
||||
}
|
||||
}
|
||||
|
||||
// clone returns a deep copy of the ipsetStore with its own ipsets map and
|
||||
// independent ipList entries.
|
||||
func (s *ipsetStore) clone() *ipsetStore {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := &ipsetStore{ipsets: make(map[string]*ipList, len(s.ipsets))}
|
||||
for name, list := range s.ipsets {
|
||||
cloned.ipsets[name] = list.clone()
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
|
||||
r, ok := s.ipsets[ipsetName]
|
||||
return r, ok
|
||||
|
||||
63
client/internal/dns/dnsfw/config.go
Normal file
63
client/internal/dns/dnsfw/config.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// EnvDisable disables the DNS firewall entirely when set to a truthy value.
|
||||
EnvDisable = "NB_DISABLE_DNS_FIREWALL"
|
||||
// EnvPorts overrides the comma-separated list of remote ports to block.
|
||||
// Empty disables the firewall.
|
||||
EnvPorts = "NB_DNS_FIREWALL_PORTS"
|
||||
// EnvStrict enables strict mode: permit DNS only to the virtual DNS IP
|
||||
// and the netbird daemon. Default mode also permits anything on the
|
||||
// netbird tunnel interface, which is safer if NRPT is silently ignored
|
||||
// by Windows but lets apps reach custom DNS servers via the tunnel.
|
||||
EnvStrict = "NB_DNS_FIREWALL_STRICT"
|
||||
)
|
||||
|
||||
// defaultBlockedPorts are the well-known DNS ports we block for non-netbird
|
||||
// processes: 53 (plain DNS) and 853 (DNS-over-TLS).
|
||||
var defaultBlockedPorts = []uint16{53, 853}
|
||||
|
||||
// blockedPorts returns the effective port list, honoring env overrides.
|
||||
// A nil return means the firewall should not be installed.
|
||||
func blockedPorts() []uint16 {
|
||||
if disabled, _ := strconv.ParseBool(os.Getenv(EnvDisable)); disabled {
|
||||
log.Infof("dns firewall disabled via %s", EnvDisable)
|
||||
return nil
|
||||
}
|
||||
|
||||
override, ok := os.LookupEnv(EnvPorts)
|
||||
if !ok {
|
||||
return defaultBlockedPorts
|
||||
}
|
||||
|
||||
var ports []uint16
|
||||
for _, raw := range strings.Split(override, ",") {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
port, err := strconv.ParseUint(raw, 10, 16)
|
||||
if err != nil {
|
||||
log.Warnf("dns firewall: ignoring invalid port %q in %s: %v", raw, EnvPorts, err)
|
||||
continue
|
||||
}
|
||||
if port == 0 {
|
||||
log.Warnf("dns firewall: ignoring port 0 in %s", EnvPorts)
|
||||
continue
|
||||
}
|
||||
ports = append(ports, uint16(port))
|
||||
}
|
||||
if len(ports) == 0 {
|
||||
log.Infof("dns firewall disabled: %s yielded no valid ports", EnvPorts)
|
||||
return nil
|
||||
}
|
||||
return ports
|
||||
}
|
||||
39
client/internal/dns/dnsfw/config_test.go
Normal file
39
client/internal/dns/dnsfw/config_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBlockedPorts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
disable string
|
||||
ports string
|
||||
setPorts bool
|
||||
want []uint16
|
||||
}{
|
||||
{name: "default", want: defaultBlockedPorts},
|
||||
{name: "disabled", disable: "true", want: nil},
|
||||
{name: "disabled false keeps default", disable: "false", want: defaultBlockedPorts},
|
||||
{name: "override single port", ports: "53", setPorts: true, want: []uint16{53}},
|
||||
{name: "override multi", ports: "53, 853 ,5353", setPorts: true, want: []uint16{53, 853, 5353}},
|
||||
{name: "override empty disables", ports: "", setPorts: true, want: nil},
|
||||
{name: "override invalid skipped", ports: "53,not-a-port,853", setPorts: true, want: []uint16{53, 853}},
|
||||
{name: "override zero skipped", ports: "53,0,853", setPorts: true, want: []uint16{53, 853}},
|
||||
{name: "override only invalid disables", ports: "abc", setPorts: true, want: nil},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(EnvDisable, tc.disable)
|
||||
if tc.setPorts {
|
||||
t.Setenv(EnvPorts, tc.ports)
|
||||
}
|
||||
got := blockedPorts()
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Fatalf("blockedPorts() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
16
client/internal/dns/dnsfw/dnsfw.go
Normal file
16
client/internal/dns/dnsfw/dnsfw.go
Normal file
@@ -0,0 +1,16 @@
|
||||
// Package dnsfw blocks DNS traffic from non-netbird processes when netbird is
|
||||
// managing the host's DNS, so that resolvers running on apps or libraries
|
||||
// outside netbird cannot bypass the configured DNS path.
|
||||
//
|
||||
// Implementation is Windows-only (uses WFP). On other platforms New returns
|
||||
// a no-op manager.
|
||||
package dnsfw
|
||||
|
||||
import "net/netip"
|
||||
|
||||
// Manager controls the per-tunnel DNS firewall. Both methods must be safe
|
||||
// to call multiple times.
|
||||
type Manager interface {
|
||||
Enable(ifaceGUID string, virtualDNSIP netip.Addr) error
|
||||
Disable() error
|
||||
}
|
||||
15
client/internal/dns/dnsfw/dnsfw_other.go
Normal file
15
client/internal/dns/dnsfw/dnsfw_other.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build !windows
|
||||
|
||||
package dnsfw
|
||||
|
||||
import "net/netip"
|
||||
|
||||
type noopManager struct{}
|
||||
|
||||
func (noopManager) Enable(string, netip.Addr) error { return nil }
|
||||
func (noopManager) Disable() error { return nil }
|
||||
|
||||
// New returns a no-op manager on non-Windows platforms.
|
||||
func New() Manager {
|
||||
return noopManager{}
|
||||
}
|
||||
144
client/internal/dns/dnsfw/dnsfw_windows.go
Normal file
144
client/internal/dns/dnsfw/dnsfw_windows.go
Normal file
@@ -0,0 +1,144 @@
|
||||
//go:build windows
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
modIphlpapi = windows.NewLazyDLL("iphlpapi.dll")
|
||||
procConvertInterfaceGuidToLuid = modIphlpapi.NewProc("ConvertInterfaceGuidToLuid")
|
||||
)
|
||||
|
||||
type windowsManager struct {
|
||||
mu sync.Mutex
|
||||
// session is the WFP engine handle. Zero when disabled.
|
||||
session uintptr
|
||||
}
|
||||
|
||||
// Enable installs the dns firewall. Strict mode propagates failures;
|
||||
// non-strict mode logs and returns nil so partial protection is preserved.
|
||||
func (m *windowsManager) Enable(ifaceGUID string, virtualDNSIP netip.Addr) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
ports := blockedPorts()
|
||||
if len(ports) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if m.session != 0 {
|
||||
if err := m.disableLocked(); err != nil {
|
||||
return fmt.Errorf("reset existing dns firewall session: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
strict := strictMode()
|
||||
|
||||
luid, err := luidFromGUID(ifaceGUID)
|
||||
if err != nil {
|
||||
return m.failOrLog(strict, fmt.Errorf("resolve tun luid from guid %s: %w", ifaceGUID, err))
|
||||
}
|
||||
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return m.failOrLog(strict, fmt.Errorf("resolve daemon executable path: %w", err))
|
||||
}
|
||||
|
||||
cfg := installConfig{
|
||||
tunLUID: luid,
|
||||
daemonExe: exe,
|
||||
blockedPorts: ports,
|
||||
strict: strict,
|
||||
virtualDNSIP: virtualDNSIP,
|
||||
}
|
||||
// session==0 signals a hard failure; non-zero with non-nil err is a partial install.
|
||||
session, installErr := installFilters(cfg)
|
||||
if session == 0 {
|
||||
return m.failOrLog(strict, fmt.Errorf("install dns firewall filters: %w", installErr))
|
||||
}
|
||||
|
||||
if installErr != nil && strict {
|
||||
_ = closeSession(session)
|
||||
return fmt.Errorf("strict dns firewall: partial install: %w", installErr)
|
||||
}
|
||||
|
||||
m.session = session
|
||||
log.Infof("dns firewall installed: iface=%s daemon=%s ports=%v strict=%v virtual_dns=%s",
|
||||
ifaceGUID, exe, ports, strict, virtualDNSIP)
|
||||
if installErr != nil {
|
||||
log.Warnf("dns firewall partially installed (some filters failed): %v", installErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *windowsManager) Disable() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.disableLocked()
|
||||
}
|
||||
|
||||
func (m *windowsManager) disableLocked() error {
|
||||
if m.session == 0 {
|
||||
return nil
|
||||
}
|
||||
session := m.session
|
||||
m.session = 0
|
||||
if err := closeSession(session); err != nil {
|
||||
return fmt.Errorf("close wfp session: %w", err)
|
||||
}
|
||||
log.Info("dns firewall removed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// failOrLog returns err unchanged in strict mode. In non-strict mode the
|
||||
// error is logged and nil is returned.
|
||||
func (m *windowsManager) failOrLog(strict bool, err error) error {
|
||||
if strict {
|
||||
return err
|
||||
}
|
||||
log.Errorf("dns firewall: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// New returns a Windows DNS firewall manager backed by WFP.
|
||||
func New() Manager {
|
||||
return &windowsManager{}
|
||||
}
|
||||
|
||||
// strictMode reports whether strict mode is enabled via env.
|
||||
func strictMode() bool {
|
||||
v, _ := strconv.ParseBool(os.Getenv(EnvStrict))
|
||||
return v
|
||||
}
|
||||
|
||||
// luidFromGUID converts a Windows interface GUID string to its LUID.
|
||||
func luidFromGUID(ifaceGUID string) (luid uint64, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in luidFromGUID: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
guid, err := windows.GUIDFromString(ifaceGUID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parse guid: %w", err)
|
||||
}
|
||||
rc, _, _ := procConvertInterfaceGuidToLuid.Call(
|
||||
uintptr(unsafe.Pointer(&guid)),
|
||||
uintptr(unsafe.Pointer(&luid)),
|
||||
)
|
||||
if rc != 0 {
|
||||
return 0, fmt.Errorf("ConvertInterfaceGuidToLuid returned %d", rc)
|
||||
}
|
||||
return luid, nil
|
||||
}
|
||||
72
client/internal/dns/dnsfw/dnsfw_windows_test.go
Normal file
72
client/internal/dns/dnsfw/dnsfw_windows_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
//go:build windows
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStrictMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val string
|
||||
set bool
|
||||
want bool
|
||||
}{
|
||||
{name: "unset", want: false},
|
||||
{name: "true", val: "true", set: true, want: true},
|
||||
{name: "1", val: "1", set: true, want: true},
|
||||
{name: "false", val: "false", set: true, want: false},
|
||||
{name: "invalid is false", val: "garbage", set: true, want: false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(EnvStrict, tc.val)
|
||||
if !tc.set {
|
||||
os.Unsetenv(EnvStrict)
|
||||
}
|
||||
if got := strictMode(); got != tc.want {
|
||||
t.Fatalf("strictMode() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowsManagerDisableIdempotent(t *testing.T) {
|
||||
m := &windowsManager{}
|
||||
if err := m.Disable(); err != nil {
|
||||
t.Fatalf("first Disable on fresh manager: %v", err)
|
||||
}
|
||||
if err := m.Disable(); err != nil {
|
||||
t.Fatalf("second Disable on fresh manager: %v", err)
|
||||
}
|
||||
if m.session != 0 {
|
||||
t.Fatalf("session should remain zero, got %d", m.session)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowsManagerEnableNoOpWhenDisabledByEnv(t *testing.T) {
|
||||
t.Setenv(EnvDisable, "true")
|
||||
|
||||
m := &windowsManager{}
|
||||
if err := m.Enable("00000000-0000-0000-0000-000000000000", netip.Addr{}); err != nil {
|
||||
t.Fatalf("Enable should be a no-op when firewall disabled by env: %v", err)
|
||||
}
|
||||
if m.session != 0 {
|
||||
t.Fatalf("session must remain zero when env disables firewall, got %d", m.session)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowsManagerEnableNoOpWhenPortsEmpty(t *testing.T) {
|
||||
t.Setenv(EnvPorts, "")
|
||||
|
||||
m := &windowsManager{}
|
||||
if err := m.Enable("00000000-0000-0000-0000-000000000000", netip.Addr{}); err != nil {
|
||||
t.Fatalf("Enable should be a no-op when ports list is empty: %v", err)
|
||||
}
|
||||
if m.session != 0 {
|
||||
t.Fatalf("session must remain zero when ports list is empty, got %d", m.session)
|
||||
}
|
||||
}
|
||||
53
client/internal/dns/dnsfw/helpers_windows.go
Normal file
53
client/internal/dns/dnsfw/helpers_windows.go
Normal file
@@ -0,0 +1,53 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/helpers.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func createWtFwpmDisplayData0(name, description string) (*wtFwpmDisplayData0, error) {
|
||||
namePtr, err := windows.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
|
||||
descriptionPtr, err := windows.UTF16PtrFromString(description)
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
|
||||
return &wtFwpmDisplayData0{
|
||||
name: namePtr,
|
||||
description: descriptionPtr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func filterWeight(weight uint8) wtFwpValue0 {
|
||||
return wtFwpValue0{
|
||||
_type: cFWP_UINT8,
|
||||
value: uintptr(weight),
|
||||
}
|
||||
}
|
||||
|
||||
func wrapErr(err error) error {
|
||||
var errno syscall.Errno
|
||||
if !errors.As(err, &errno) {
|
||||
return err
|
||||
}
|
||||
_, file, line, ok := runtime.Caller(1)
|
||||
if !ok {
|
||||
return fmt.Errorf("wfp error at unknown location: %w", err)
|
||||
}
|
||||
return fmt.Errorf("wfp error at %s:%d: %w", file, line, err)
|
||||
}
|
||||
249
client/internal/dns/dnsfw/rules_windows.go
Normal file
249
client/internal/dns/dnsfw/rules_windows.go
Normal file
@@ -0,0 +1,249 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2026 NetBird GmbH. All Rights Reserved.
|
||||
*
|
||||
* Filter installers adapted from wireguard-windows tunnel/firewall/rules.go.
|
||||
* The block-DNS approach (port 53 + UDP/TCP) matches what wireguard-windows
|
||||
* uses for its kill-switch DNS leak protection. We extend it with a
|
||||
* configurable port set so we also cover :853 (DoT) and any future ports.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
// Filters install at outbound ALE_AUTH_CONNECT layers only; inbound replies
|
||||
// follow the authorized outbound flow.
|
||||
|
||||
// permitTunInterface installs a permit filter for any traffic whose local
|
||||
// interface is the netbird tunnel.
|
||||
func permitTunInterface(session uintptr, base *baseObjects, weight uint8, ifLUID uint64) error {
|
||||
cond := wtFwpmFilterCondition0{
|
||||
fieldKey: cFWPM_CONDITION_IP_LOCAL_INTERFACE,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT64,
|
||||
value: uintptr(unsafe.Pointer(&ifLUID)),
|
||||
},
|
||||
}
|
||||
|
||||
filter := wtFwpmFilter0{
|
||||
providerKey: &base.provider,
|
||||
subLayerKey: base.filters,
|
||||
weight: filterWeight(weight),
|
||||
numFilterConditions: 1,
|
||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&cond)),
|
||||
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
|
||||
}
|
||||
|
||||
return addOutboundFilters(session, &filter, "Permit netbird tunnel")
|
||||
}
|
||||
|
||||
// permitDaemonByAppID installs a permit filter matching the netbird daemon
|
||||
// executable by App-ID. App-ID alone is sufficient because netbird.exe is a
|
||||
// dedicated binary.
|
||||
func permitDaemonByAppID(session uintptr, base *baseObjects, daemonExe string, weight uint8) error {
|
||||
appID, err := daemonAppID(daemonExe)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fwpmFreeMemory0(unsafe.Pointer(&appID))
|
||||
|
||||
cond := wtFwpmFilterCondition0{
|
||||
fieldKey: cFWPM_CONDITION_ALE_APP_ID,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_BYTE_BLOB_TYPE,
|
||||
value: uintptr(unsafe.Pointer(appID)),
|
||||
},
|
||||
}
|
||||
|
||||
filter := wtFwpmFilter0{
|
||||
providerKey: &base.provider,
|
||||
subLayerKey: base.filters,
|
||||
weight: filterWeight(weight),
|
||||
numFilterConditions: 1,
|
||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&cond)),
|
||||
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
|
||||
}
|
||||
|
||||
return addOutboundFilters(session, &filter, "Permit netbird daemon")
|
||||
}
|
||||
|
||||
// permitVirtualDNSIP installs a permit filter for DNS-port traffic destined
|
||||
// for the in-tunnel virtual DNS IP. Used in strict mode in lieu of
|
||||
// permitTunInterface.
|
||||
func permitVirtualDNSIP(session uintptr, base *baseObjects, ip netip.Addr, ports []uint16, weight uint8) error {
|
||||
var merr *multierror.Error
|
||||
for _, port := range ports {
|
||||
if err := permitDNSToHost(session, base, ip, port, weight); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("permit %s:%d: %w", ip, port, err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func permitDNSToHost(session uintptr, base *baseObjects, ip netip.Addr, port uint16, weight uint8) error {
|
||||
if !ip.IsValid() {
|
||||
return fmt.Errorf("invalid address")
|
||||
}
|
||||
|
||||
var addrCond wtFwpmFilterCondition0
|
||||
var layer windows.GUID
|
||||
// v6 backing must outlive fwpmFilterAdd0; keep it on this stack frame.
|
||||
var v6 wtFwpByteArray16
|
||||
|
||||
if ip.Is4() {
|
||||
v4 := ip.As4()
|
||||
addrCond = wtFwpmFilterCondition0{
|
||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT32,
|
||||
value: uintptr(binary.BigEndian.Uint32(v4[:])),
|
||||
},
|
||||
}
|
||||
layer = cFWPM_LAYER_ALE_AUTH_CONNECT_V4
|
||||
} else {
|
||||
v6 = wtFwpByteArray16{byteArray16: ip.As16()}
|
||||
addrCond = wtFwpmFilterCondition0{
|
||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_BYTE_ARRAY16_TYPE,
|
||||
value: uintptr(unsafe.Pointer(&v6)),
|
||||
},
|
||||
}
|
||||
layer = cFWPM_LAYER_ALE_AUTH_CONNECT_V6
|
||||
}
|
||||
|
||||
conditions := [2]wtFwpmFilterCondition0{
|
||||
addrCond,
|
||||
{
|
||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_PORT,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT16,
|
||||
value: uintptr(port),
|
||||
},
|
||||
},
|
||||
}
|
||||
filter := wtFwpmFilter0{
|
||||
providerKey: &base.provider,
|
||||
subLayerKey: base.filters,
|
||||
weight: filterWeight(weight),
|
||||
numFilterConditions: uint32(len(conditions)),
|
||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&conditions[0])),
|
||||
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
|
||||
}
|
||||
|
||||
display, err := createWtFwpmDisplayData0(fmt.Sprintf("Permit DNS to %s:%d", ip, port), "")
|
||||
if err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
filter.displayData = *display
|
||||
filter.layerKey = layer
|
||||
|
||||
var filterID uint64
|
||||
if err := fwpmFilterAdd0(session, &filter, 0, &filterID); err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
_ = v6
|
||||
return nil
|
||||
}
|
||||
|
||||
// blockDNSPorts installs a deny filter for outbound traffic to each of the
|
||||
// given remote ports over UDP or TCP. Per-port and per-layer failures are
|
||||
// accumulated; partial coverage is preferred over zero coverage.
|
||||
func blockDNSPorts(session uintptr, base *baseObjects, ports []uint16, weight uint8) error {
|
||||
var merr *multierror.Error
|
||||
for _, port := range ports {
|
||||
if err := blockDNSPort(session, base, port, weight); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("block port %d: %w", port, err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func blockDNSPort(session uintptr, base *baseObjects, port uint16, weight uint8) error {
|
||||
conditions := [3]wtFwpmFilterCondition0{
|
||||
{
|
||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_PORT,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT16,
|
||||
value: uintptr(port),
|
||||
},
|
||||
},
|
||||
{
|
||||
fieldKey: cFWPM_CONDITION_IP_PROTOCOL,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT8,
|
||||
value: uintptr(cIPPROTO_UDP),
|
||||
},
|
||||
},
|
||||
// Repeat the IP_PROTOCOL condition for logical OR with TCP.
|
||||
{
|
||||
fieldKey: cFWPM_CONDITION_IP_PROTOCOL,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT8,
|
||||
value: uintptr(cIPPROTO_TCP),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
filter := wtFwpmFilter0{
|
||||
providerKey: &base.provider,
|
||||
subLayerKey: base.filters,
|
||||
weight: filterWeight(weight),
|
||||
numFilterConditions: uint32(len(conditions)),
|
||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&conditions[0])),
|
||||
action: wtFwpmAction0{_type: cFWP_ACTION_BLOCK},
|
||||
}
|
||||
|
||||
return addOutboundFilters(session, &filter, fmt.Sprintf("Block DNS port %d", port))
|
||||
}
|
||||
|
||||
// addOutboundFilters installs the same filter on the v4 and v6 outbound ALE
|
||||
// connect layers. v4 and v6 are installed independently: failure on one
|
||||
// layer does not abort the other, and the accumulated errors are returned.
|
||||
// Partial coverage is preferred over zero coverage.
|
||||
func addOutboundFilters(session uintptr, filter *wtFwpmFilter0, name string) error {
|
||||
layers := [...]struct {
|
||||
layer windows.GUID
|
||||
label string
|
||||
}{
|
||||
{cFWPM_LAYER_ALE_AUTH_CONNECT_V4, name + " (IPv4)"},
|
||||
{cFWPM_LAYER_ALE_AUTH_CONNECT_V6, name + " (IPv6)"},
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, l := range layers {
|
||||
display, err := createWtFwpmDisplayData0(l.label, "")
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("%s: %w", l.label, wrapErr(err)))
|
||||
continue
|
||||
}
|
||||
filter.displayData = *display
|
||||
filter.layerKey = l.layer
|
||||
|
||||
var filterID uint64
|
||||
if err := fwpmFilterAdd0(session, filter, 0, &filterID); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("%s: %w", l.label, wrapErr(err)))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
177
client/internal/dns/dnsfw/session_windows.go
Normal file
177
client/internal/dns/dnsfw/session_windows.go
Normal file
@@ -0,0 +1,177 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2026 NetBird GmbH. All Rights Reserved.
|
||||
*
|
||||
* Session lifecycle and the high-level Install/Close entry points adapted
|
||||
* from wireguard-windows tunnel/firewall.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
// installConfig is the input to installFilters.
|
||||
type installConfig struct {
|
||||
tunLUID uint64
|
||||
daemonExe string
|
||||
blockedPorts []uint16
|
||||
// strict, when true, narrows the carve-out from "anything on tun" to
|
||||
// "DNS only to virtualDNSIP". virtualDNSIP must be valid in this case.
|
||||
strict bool
|
||||
virtualDNSIP netip.Addr
|
||||
}
|
||||
|
||||
// baseObjects holds the GUIDs of the WFP provider and sublayer registered
|
||||
// for our session. Both are randomly generated per session.
|
||||
type baseObjects struct {
|
||||
provider windows.GUID
|
||||
filters windows.GUID
|
||||
}
|
||||
|
||||
// installFilters opens a dynamic WFP session and installs the netbird DNS
|
||||
// firewall filters. Returns a zero session on hard failure (session create,
|
||||
// base objects); a non-zero session with a non-nil error is a partial install
|
||||
// (some per-filter installs failed) and is safe to close.
|
||||
func installFilters(cfg installConfig) (session uintptr, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Dynamic session: kernel will clean up on process exit even
|
||||
// if we leave the handle dangling here.
|
||||
err = fmt.Errorf("panic in installFilters: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
if len(cfg.blockedPorts) == 0 {
|
||||
return 0, errors.New("dns firewall: no blocked ports configured")
|
||||
}
|
||||
if cfg.strict && !cfg.virtualDNSIP.IsValid() {
|
||||
return 0, errors.New("dns firewall: strict mode requires a valid virtual DNS IP")
|
||||
}
|
||||
|
||||
session, err = createSession()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
base, err := registerBaseObjects(session)
|
||||
if err != nil {
|
||||
_ = fwpmEngineClose0(session)
|
||||
return 0, fmt.Errorf("register base objects: %w", err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
if cfg.strict {
|
||||
if err := permitVirtualDNSIP(session, base, cfg.virtualDNSIP, cfg.blockedPorts, 15); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("permit virtual dns: %w", err))
|
||||
}
|
||||
} else {
|
||||
if err := permitTunInterface(session, base, 15, cfg.tunLUID); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("permit tun interface: %w", err))
|
||||
}
|
||||
}
|
||||
if err := permitDaemonByAppID(session, base, cfg.daemonExe, 14); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("permit netbird daemon: %w", err))
|
||||
}
|
||||
if err := blockDNSPorts(session, base, cfg.blockedPorts, 10); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("block dns ports: %w", err))
|
||||
}
|
||||
|
||||
return session, nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// closeSession tears down a WFP session previously opened by installFilters.
|
||||
// All filters owned by the session are removed.
|
||||
func closeSession(session uintptr) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in closeSession: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
if session == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := fwpmEngineClose0(session); err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func createSession() (uintptr, error) {
|
||||
displayData, err := createWtFwpmDisplayData0("NetBird DNS firewall", "NetBird DNS firewall dynamic session")
|
||||
if err != nil {
|
||||
return 0, wrapErr(err)
|
||||
}
|
||||
session := wtFwpmSession0{
|
||||
displayData: *displayData,
|
||||
flags: cFWPM_SESSION_FLAG_DYNAMIC,
|
||||
txnWaitTimeoutInMSec: windows.INFINITE,
|
||||
}
|
||||
var handle uintptr
|
||||
if err := fwpmEngineOpen0(nil, cRPC_C_AUTHN_WINNT, nil, &session, unsafe.Pointer(&handle)); err != nil {
|
||||
return 0, wrapErr(err)
|
||||
}
|
||||
return handle, nil
|
||||
}
|
||||
|
||||
func registerBaseObjects(session uintptr) (*baseObjects, error) {
|
||||
bo := &baseObjects{}
|
||||
var err error
|
||||
if bo.provider, err = windows.GenerateGUID(); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
if bo.filters, err = windows.GenerateGUID(); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
|
||||
displayData, err := createWtFwpmDisplayData0("NetBird DNS firewall", "NetBird DNS firewall provider")
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
provider := wtFwpmProvider0{
|
||||
providerKey: bo.provider,
|
||||
displayData: *displayData,
|
||||
}
|
||||
if err := fwpmProviderAdd0(session, &provider, 0); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
|
||||
subDisplay, err := createWtFwpmDisplayData0("NetBird DNS firewall filters", "Permit and block filters")
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
sublayer := wtFwpmSublayer0{
|
||||
subLayerKey: bo.filters,
|
||||
displayData: *subDisplay,
|
||||
providerKey: &bo.provider,
|
||||
weight: ^uint16(0),
|
||||
}
|
||||
if err := fwpmSubLayerAdd0(session, &sublayer, 0); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
return bo, nil
|
||||
}
|
||||
|
||||
// daemonAppID returns the WFP App-ID byte blob for the given executable path.
|
||||
func daemonAppID(path string) (*wtFwpByteBlob, error) {
|
||||
pathPtr, err := windows.UTF16PtrFromString(path)
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
var appID *wtFwpByteBlob
|
||||
if err := fwpmGetAppIdFromFileName0(pathPtr, unsafe.Pointer(&appID)); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
return appID, nil
|
||||
}
|
||||
38
client/internal/dns/dnsfw/syscall_windows.go
Normal file
38
client/internal/dns/dnsfw/syscall_windows.go
Normal file
@@ -0,0 +1,38 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/syscall_windows.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmengineopen0
|
||||
//sys fwpmEngineOpen0(serverName *uint16, authnService wtRpcCAuthN, authIdentity *uintptr, session *wtFwpmSession0, engineHandle unsafe.Pointer) (err error) [failretval!=0] = fwpuclnt.FwpmEngineOpen0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmengineclose0
|
||||
//sys fwpmEngineClose0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmEngineClose0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmsublayeradd0
|
||||
//sys fwpmSubLayerAdd0(engineHandle uintptr, subLayer *wtFwpmSublayer0, sd uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmSubLayerAdd0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmgetappidfromfilename0
|
||||
//sys fwpmGetAppIdFromFileName0(fileName *uint16, appID unsafe.Pointer) (err error) [failretval!=0] = fwpuclnt.FwpmGetAppIdFromFileName0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmfreememory0
|
||||
//sys fwpmFreeMemory0(p unsafe.Pointer) = fwpuclnt.FwpmFreeMemory0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmfilteradd0
|
||||
//sys fwpmFilterAdd0(engineHandle uintptr, filter *wtFwpmFilter0, sd uintptr, id *uint64) (err error) [failretval!=0] = fwpuclnt.FwpmFilterAdd0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/Fwpmu/nf-fwpmu-fwpmtransactionbegin0
|
||||
//sys fwpmTransactionBegin0(engineHandle uintptr, flags uint32) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionBegin0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmtransactioncommit0
|
||||
//sys fwpmTransactionCommit0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionCommit0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmtransactionabort0
|
||||
//sys fwpmTransactionAbort0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionAbort0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmprovideradd0
|
||||
//sys fwpmProviderAdd0(engineHandle uintptr, provider *wtFwpmProvider0, sd uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmProviderAdd0
|
||||
414
client/internal/dns/dnsfw/types_windows.go
Normal file
414
client/internal/dns/dnsfw/types_windows.go
Normal file
@@ -0,0 +1,414 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/types_windows.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
const (
|
||||
anysizeArray = 1 // ANYSIZE_ARRAY defined in winnt.h
|
||||
|
||||
wtFwpBitmapArray64_Size = 8
|
||||
|
||||
wtFwpByteArray16_Size = 16
|
||||
|
||||
wtFwpByteArray6_Size = 6
|
||||
|
||||
wtFwpmAction0_Size = 20
|
||||
wtFwpmAction0_filterType_Offset = 4
|
||||
|
||||
wtFwpV4AddrAndMask_Size = 8
|
||||
wtFwpV4AddrAndMask_mask_Offset = 4
|
||||
|
||||
wtFwpV6AddrAndMask_Size = 17
|
||||
wtFwpV6AddrAndMask_prefixLength_Offset = 16
|
||||
)
|
||||
|
||||
type wtFwpActionFlag uint32
|
||||
|
||||
const (
|
||||
cFWP_ACTION_FLAG_TERMINATING wtFwpActionFlag = 0x00001000
|
||||
cFWP_ACTION_FLAG_NON_TERMINATING wtFwpActionFlag = 0x00002000
|
||||
cFWP_ACTION_FLAG_CALLOUT wtFwpActionFlag = 0x00004000
|
||||
)
|
||||
|
||||
// FWP_ACTION_TYPE defined in fwptypes.h
|
||||
type wtFwpActionType uint32
|
||||
|
||||
const (
|
||||
cFWP_ACTION_BLOCK wtFwpActionType = wtFwpActionType(0x00000001 | cFWP_ACTION_FLAG_TERMINATING)
|
||||
cFWP_ACTION_PERMIT wtFwpActionType = wtFwpActionType(0x00000002 | cFWP_ACTION_FLAG_TERMINATING)
|
||||
cFWP_ACTION_CALLOUT_TERMINATING wtFwpActionType = wtFwpActionType(0x00000003 | cFWP_ACTION_FLAG_CALLOUT | cFWP_ACTION_FLAG_TERMINATING)
|
||||
cFWP_ACTION_CALLOUT_INSPECTION wtFwpActionType = wtFwpActionType(0x00000004 | cFWP_ACTION_FLAG_CALLOUT | cFWP_ACTION_FLAG_NON_TERMINATING)
|
||||
cFWP_ACTION_CALLOUT_UNKNOWN wtFwpActionType = wtFwpActionType(0x00000005 | cFWP_ACTION_FLAG_CALLOUT)
|
||||
cFWP_ACTION_CONTINUE wtFwpActionType = wtFwpActionType(0x00000006 | cFWP_ACTION_FLAG_NON_TERMINATING)
|
||||
cFWP_ACTION_NONE wtFwpActionType = 0x00000007
|
||||
cFWP_ACTION_NONE_NO_MATCH wtFwpActionType = 0x00000008
|
||||
cFWP_ACTION_BITMAP_INDEX_SET wtFwpActionType = 0x00000009
|
||||
)
|
||||
|
||||
// FWP_BYTE_BLOB defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_blob_)
|
||||
type wtFwpByteBlob struct {
|
||||
size uint32
|
||||
data *uint8
|
||||
}
|
||||
|
||||
// FWP_MATCH_TYPE defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ne-fwptypes-fwp_match_type_)
|
||||
type wtFwpMatchType uint32
|
||||
|
||||
const (
|
||||
cFWP_MATCH_EQUAL wtFwpMatchType = 0
|
||||
cFWP_MATCH_GREATER wtFwpMatchType = cFWP_MATCH_EQUAL + 1
|
||||
cFWP_MATCH_LESS wtFwpMatchType = cFWP_MATCH_GREATER + 1
|
||||
cFWP_MATCH_GREATER_OR_EQUAL wtFwpMatchType = cFWP_MATCH_LESS + 1
|
||||
cFWP_MATCH_LESS_OR_EQUAL wtFwpMatchType = cFWP_MATCH_GREATER_OR_EQUAL + 1
|
||||
cFWP_MATCH_RANGE wtFwpMatchType = cFWP_MATCH_LESS_OR_EQUAL + 1
|
||||
cFWP_MATCH_FLAGS_ALL_SET wtFwpMatchType = cFWP_MATCH_RANGE + 1
|
||||
cFWP_MATCH_FLAGS_ANY_SET wtFwpMatchType = cFWP_MATCH_FLAGS_ALL_SET + 1
|
||||
cFWP_MATCH_FLAGS_NONE_SET wtFwpMatchType = cFWP_MATCH_FLAGS_ANY_SET + 1
|
||||
cFWP_MATCH_EQUAL_CASE_INSENSITIVE wtFwpMatchType = cFWP_MATCH_FLAGS_NONE_SET + 1
|
||||
cFWP_MATCH_NOT_EQUAL wtFwpMatchType = cFWP_MATCH_EQUAL_CASE_INSENSITIVE + 1
|
||||
cFWP_MATCH_PREFIX wtFwpMatchType = cFWP_MATCH_NOT_EQUAL + 1
|
||||
cFWP_MATCH_NOT_PREFIX wtFwpMatchType = cFWP_MATCH_PREFIX + 1
|
||||
cFWP_MATCH_TYPE_MAX wtFwpMatchType = cFWP_MATCH_NOT_PREFIX + 1
|
||||
)
|
||||
|
||||
// FWPM_ACTION0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_action0_)
|
||||
type wtFwpmAction0 struct {
|
||||
_type wtFwpActionType
|
||||
filterType windows.GUID // Windows type: GUID
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. 4cd62a49-59c3-4969-b7f3-bda5d32890a4
|
||||
var cFWPM_CONDITION_IP_LOCAL_INTERFACE = windows.GUID{
|
||||
Data1: 0x4cd62a49,
|
||||
Data2: 0x59c3,
|
||||
Data3: 0x4969,
|
||||
Data4: [8]byte{0xb7, 0xf3, 0xbd, 0xa5, 0xd3, 0x28, 0x90, 0xa4},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. b235ae9a-1d64-49b8-a44c-5ff3d9095045
|
||||
var cFWPM_CONDITION_IP_REMOTE_ADDRESS = windows.GUID{
|
||||
Data1: 0xb235ae9a,
|
||||
Data2: 0x1d64,
|
||||
Data3: 0x49b8,
|
||||
Data4: [8]byte{0xa4, 0x4c, 0x5f, 0xf3, 0xd9, 0x09, 0x50, 0x45},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. 3971ef2b-623e-4f9a-8cb1-6e79b806b9a7
|
||||
var cFWPM_CONDITION_IP_PROTOCOL = windows.GUID{
|
||||
Data1: 0x3971ef2b,
|
||||
Data2: 0x623e,
|
||||
Data3: 0x4f9a,
|
||||
Data4: [8]byte{0x8c, 0xb1, 0x6e, 0x79, 0xb8, 0x06, 0xb9, 0xa7},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. 0c1ba1af-5765-453f-af22-a8f791ac775b
|
||||
var cFWPM_CONDITION_IP_LOCAL_PORT = windows.GUID{
|
||||
Data1: 0x0c1ba1af,
|
||||
Data2: 0x5765,
|
||||
Data3: 0x453f,
|
||||
Data4: [8]byte{0xaf, 0x22, 0xa8, 0xf7, 0x91, 0xac, 0x77, 0x5b},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. c35a604d-d22b-4e1a-91b4-68f674ee674b
|
||||
var cFWPM_CONDITION_IP_REMOTE_PORT = windows.GUID{
|
||||
Data1: 0xc35a604d,
|
||||
Data2: 0xd22b,
|
||||
Data3: 0x4e1a,
|
||||
Data4: [8]byte{0x91, 0xb4, 0x68, 0xf6, 0x74, 0xee, 0x67, 0x4b},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. d78e1e87-8644-4ea5-9437-d809ecefc971
|
||||
var cFWPM_CONDITION_ALE_APP_ID = windows.GUID{
|
||||
Data1: 0xd78e1e87,
|
||||
Data2: 0x8644,
|
||||
Data3: 0x4ea5,
|
||||
Data4: [8]byte{0x94, 0x37, 0xd8, 0x09, 0xec, 0xef, 0xc9, 0x71},
|
||||
}
|
||||
|
||||
// af043a0a-b34d-4f86-979c-c90371af6e66
|
||||
var cFWPM_CONDITION_ALE_USER_ID = windows.GUID{
|
||||
Data1: 0xaf043a0a,
|
||||
Data2: 0xb34d,
|
||||
Data3: 0x4f86,
|
||||
Data4: [8]byte{0x97, 0x9c, 0xc9, 0x03, 0x71, 0xaf, 0x6e, 0x66},
|
||||
}
|
||||
|
||||
// d9ee00de-c1ef-4617-bfe3-ffd8f5a08957
|
||||
var cFWPM_CONDITION_IP_LOCAL_ADDRESS = windows.GUID{
|
||||
Data1: 0xd9ee00de,
|
||||
Data2: 0xc1ef,
|
||||
Data3: 0x4617,
|
||||
Data4: [8]byte{0xbf, 0xe3, 0xff, 0xd8, 0xf5, 0xa0, 0x89, 0x57},
|
||||
}
|
||||
|
||||
var (
|
||||
cFWPM_CONDITION_ICMP_TYPE = cFWPM_CONDITION_IP_LOCAL_PORT
|
||||
cFWPM_CONDITION_ICMP_CODE = cFWPM_CONDITION_IP_REMOTE_PORT
|
||||
)
|
||||
|
||||
// 7bc43cbf-37ba-45f1-b74a-82ff518eeb10
|
||||
var cFWPM_CONDITION_L2_FLAGS = windows.GUID{
|
||||
Data1: 0x7bc43cbf,
|
||||
Data2: 0x37ba,
|
||||
Data3: 0x45f1,
|
||||
Data4: [8]byte{0xb7, 0x4a, 0x82, 0xff, 0x51, 0x8e, 0xeb, 0x10},
|
||||
}
|
||||
|
||||
type wtFwpmL2Flags uint32
|
||||
|
||||
const cFWP_CONDITION_L2_IS_VM2VM wtFwpmL2Flags = 0x00000010
|
||||
|
||||
var cFWPM_CONDITION_FLAGS = windows.GUID{
|
||||
Data1: 0x632ce23b,
|
||||
Data2: 0x5167,
|
||||
Data3: 0x435c,
|
||||
Data4: [8]byte{0x86, 0xd7, 0xe9, 0x03, 0x68, 0x4a, 0xa8, 0x0c},
|
||||
}
|
||||
|
||||
type wtFwpmFlags uint32
|
||||
|
||||
const cFWP_CONDITION_FLAG_IS_LOOPBACK wtFwpmFlags = 0x00000001
|
||||
|
||||
// Defined in fwpmtypes.h
|
||||
type wtFwpmFilterFlags uint32
|
||||
|
||||
const (
|
||||
cFWPM_FILTER_FLAG_NONE wtFwpmFilterFlags = 0x00000000
|
||||
cFWPM_FILTER_FLAG_PERSISTENT wtFwpmFilterFlags = 0x00000001
|
||||
cFWPM_FILTER_FLAG_BOOTTIME wtFwpmFilterFlags = 0x00000002
|
||||
cFWPM_FILTER_FLAG_HAS_PROVIDER_CONTEXT wtFwpmFilterFlags = 0x00000004
|
||||
cFWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT wtFwpmFilterFlags = 0x00000008
|
||||
cFWPM_FILTER_FLAG_PERMIT_IF_CALLOUT_UNREGISTERED wtFwpmFilterFlags = 0x00000010
|
||||
cFWPM_FILTER_FLAG_DISABLED wtFwpmFilterFlags = 0x00000020
|
||||
cFWPM_FILTER_FLAG_INDEXED wtFwpmFilterFlags = 0x00000040
|
||||
cFWPM_FILTER_FLAG_HAS_SECURITY_REALM_PROVIDER_CONTEXT wtFwpmFilterFlags = 0x00000080
|
||||
cFWPM_FILTER_FLAG_SYSTEMOS_ONLY wtFwpmFilterFlags = 0x00000100
|
||||
cFWPM_FILTER_FLAG_GAMEOS_ONLY wtFwpmFilterFlags = 0x00000200
|
||||
cFWPM_FILTER_FLAG_SILENT_MODE wtFwpmFilterFlags = 0x00000400
|
||||
cFWPM_FILTER_FLAG_IPSEC_NO_ACQUIRE_INITIATE wtFwpmFilterFlags = 0x00000800
|
||||
)
|
||||
|
||||
// FWPM_LAYER_ALE_AUTH_CONNECT_V4 (c38d57d1-05a7-4c33-904f-7fbceee60e82) defined in fwpmu.h
|
||||
var cFWPM_LAYER_ALE_AUTH_CONNECT_V4 = windows.GUID{
|
||||
Data1: 0xc38d57d1,
|
||||
Data2: 0x05a7,
|
||||
Data3: 0x4c33,
|
||||
Data4: [8]byte{0x90, 0x4f, 0x7f, 0xbc, 0xee, 0xe6, 0x0e, 0x82},
|
||||
}
|
||||
|
||||
// e1cd9fe7-f4b5-4273-96c0-592e487b8650
|
||||
var cFWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 = windows.GUID{
|
||||
Data1: 0xe1cd9fe7,
|
||||
Data2: 0xf4b5,
|
||||
Data3: 0x4273,
|
||||
Data4: [8]byte{0x96, 0xc0, 0x59, 0x2e, 0x48, 0x7b, 0x86, 0x50},
|
||||
}
|
||||
|
||||
// FWPM_LAYER_ALE_AUTH_CONNECT_V6 (4a72393b-319f-44bc-84c3-ba54dcb3b6b4) defined in fwpmu.h
|
||||
var cFWPM_LAYER_ALE_AUTH_CONNECT_V6 = windows.GUID{
|
||||
Data1: 0x4a72393b,
|
||||
Data2: 0x319f,
|
||||
Data3: 0x44bc,
|
||||
Data4: [8]byte{0x84, 0xc3, 0xba, 0x54, 0xdc, 0xb3, 0xb6, 0xb4},
|
||||
}
|
||||
|
||||
// a3b42c97-9f04-4672-b87e-cee9c483257f
|
||||
var cFWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6 = windows.GUID{
|
||||
Data1: 0xa3b42c97,
|
||||
Data2: 0x9f04,
|
||||
Data3: 0x4672,
|
||||
Data4: [8]byte{0xb8, 0x7e, 0xce, 0xe9, 0xc4, 0x83, 0x25, 0x7f},
|
||||
}
|
||||
|
||||
// 94c44912-9d6f-4ebf-b995-05ab8a088d1b
|
||||
var cFWPM_LAYER_OUTBOUND_MAC_FRAME_NATIVE = windows.GUID{
|
||||
Data1: 0x94c44912,
|
||||
Data2: 0x9d6f,
|
||||
Data3: 0x4ebf,
|
||||
Data4: [8]byte{0xb9, 0x95, 0x05, 0xab, 0x8a, 0x08, 0x8d, 0x1b},
|
||||
}
|
||||
|
||||
// d4220bd3-62ce-4f08-ae88-b56e8526df50
|
||||
var cFWPM_LAYER_INBOUND_MAC_FRAME_NATIVE = windows.GUID{
|
||||
Data1: 0xd4220bd3,
|
||||
Data2: 0x62ce,
|
||||
Data3: 0x4f08,
|
||||
Data4: [8]byte{0xae, 0x88, 0xb5, 0x6e, 0x85, 0x26, 0xdf, 0x50},
|
||||
}
|
||||
|
||||
// FWP_BITMAP_ARRAY64 defined in fwtypes.h
|
||||
type wtFwpBitmapArray64 struct {
|
||||
bitmapArray64 [8]uint8 // Windows type: [8]UINT8
|
||||
}
|
||||
|
||||
// FWP_BYTE_ARRAY6 defined in fwtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_array6_)
|
||||
type wtFwpByteArray6 struct {
|
||||
byteArray6 [6]uint8 // Windows type: [6]UINT8
|
||||
}
|
||||
|
||||
// FWP_BYTE_ARRAY16 defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_array16_)
|
||||
type wtFwpByteArray16 struct {
|
||||
byteArray16 [16]uint8 // Windows type [16]UINT8
|
||||
}
|
||||
|
||||
// FWP_CONDITION_VALUE0 defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_condition_value0).
|
||||
type wtFwpConditionValue0 wtFwpValue0
|
||||
|
||||
// FWP_DATA_TYPE defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ne-fwptypes-fwp_data_type_)
|
||||
type wtFwpDataType uint
|
||||
|
||||
const (
|
||||
cFWP_EMPTY wtFwpDataType = 0
|
||||
cFWP_UINT8 wtFwpDataType = cFWP_EMPTY + 1
|
||||
cFWP_UINT16 wtFwpDataType = cFWP_UINT8 + 1
|
||||
cFWP_UINT32 wtFwpDataType = cFWP_UINT16 + 1
|
||||
cFWP_UINT64 wtFwpDataType = cFWP_UINT32 + 1
|
||||
cFWP_INT8 wtFwpDataType = cFWP_UINT64 + 1
|
||||
cFWP_INT16 wtFwpDataType = cFWP_INT8 + 1
|
||||
cFWP_INT32 wtFwpDataType = cFWP_INT16 + 1
|
||||
cFWP_INT64 wtFwpDataType = cFWP_INT32 + 1
|
||||
cFWP_FLOAT wtFwpDataType = cFWP_INT64 + 1
|
||||
cFWP_DOUBLE wtFwpDataType = cFWP_FLOAT + 1
|
||||
cFWP_BYTE_ARRAY16_TYPE wtFwpDataType = cFWP_DOUBLE + 1
|
||||
cFWP_BYTE_BLOB_TYPE wtFwpDataType = cFWP_BYTE_ARRAY16_TYPE + 1
|
||||
cFWP_SID wtFwpDataType = cFWP_BYTE_BLOB_TYPE + 1
|
||||
cFWP_SECURITY_DESCRIPTOR_TYPE wtFwpDataType = cFWP_SID + 1
|
||||
cFWP_TOKEN_INFORMATION_TYPE wtFwpDataType = cFWP_SECURITY_DESCRIPTOR_TYPE + 1
|
||||
cFWP_TOKEN_ACCESS_INFORMATION_TYPE wtFwpDataType = cFWP_TOKEN_INFORMATION_TYPE + 1
|
||||
cFWP_UNICODE_STRING_TYPE wtFwpDataType = cFWP_TOKEN_ACCESS_INFORMATION_TYPE + 1
|
||||
cFWP_BYTE_ARRAY6_TYPE wtFwpDataType = cFWP_UNICODE_STRING_TYPE + 1
|
||||
cFWP_BITMAP_INDEX_TYPE wtFwpDataType = cFWP_BYTE_ARRAY6_TYPE + 1
|
||||
cFWP_BITMAP_ARRAY64_TYPE wtFwpDataType = cFWP_BITMAP_INDEX_TYPE + 1
|
||||
cFWP_SINGLE_DATA_TYPE_MAX wtFwpDataType = 0xff
|
||||
cFWP_V4_ADDR_MASK wtFwpDataType = cFWP_SINGLE_DATA_TYPE_MAX + 1
|
||||
cFWP_V6_ADDR_MASK wtFwpDataType = cFWP_V4_ADDR_MASK + 1
|
||||
cFWP_RANGE_TYPE wtFwpDataType = cFWP_V6_ADDR_MASK + 1
|
||||
cFWP_DATA_TYPE_MAX wtFwpDataType = cFWP_RANGE_TYPE + 1
|
||||
)
|
||||
|
||||
// FWP_V4_ADDR_AND_MASK defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_v4_addr_and_mask).
|
||||
type wtFwpV4AddrAndMask struct {
|
||||
addr uint32
|
||||
mask uint32
|
||||
}
|
||||
|
||||
// FWP_V6_ADDR_AND_MASK defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_v6_addr_and_mask).
|
||||
type wtFwpV6AddrAndMask struct {
|
||||
addr [16]uint8
|
||||
prefixLength uint8
|
||||
}
|
||||
|
||||
// FWP_VALUE0 defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_value0_)
|
||||
type wtFwpValue0 struct {
|
||||
_type wtFwpDataType
|
||||
value uintptr
|
||||
}
|
||||
|
||||
// FWPM_DISPLAY_DATA0 defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwpm_display_data0).
|
||||
type wtFwpmDisplayData0 struct {
|
||||
name *uint16 // Windows type: *wchar_t
|
||||
description *uint16 // Windows type: *wchar_t
|
||||
}
|
||||
|
||||
// FWPM_FILTER_CONDITION0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter_condition0).
|
||||
type wtFwpmFilterCondition0 struct {
|
||||
fieldKey windows.GUID // Windows type: GUID
|
||||
matchType wtFwpMatchType
|
||||
conditionValue wtFwpConditionValue0
|
||||
}
|
||||
|
||||
// FWPM_PROVIDER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_provider0_)
|
||||
type wtFwpProvider0 struct {
|
||||
providerKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags uint32
|
||||
providerData wtFwpByteBlob
|
||||
serviceName *uint16 // Windows type: *wchar_t
|
||||
}
|
||||
|
||||
type wtFwpmSessionFlagsValue uint32
|
||||
|
||||
const (
|
||||
cFWPM_SESSION_FLAG_DYNAMIC wtFwpmSessionFlagsValue = 0x00000001 // FWPM_SESSION_FLAG_DYNAMIC defined in fwpmtypes.h
|
||||
)
|
||||
|
||||
// FWPM_SESSION0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_session0).
|
||||
type wtFwpmSession0 struct {
|
||||
sessionKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags wtFwpmSessionFlagsValue // Windows type UINT32
|
||||
txnWaitTimeoutInMSec uint32
|
||||
processId uint32 // Windows type: DWORD
|
||||
sid *windows.SID
|
||||
username *uint16 // Windows type: *wchar_t
|
||||
kernelMode uint8 // Windows type: BOOL
|
||||
}
|
||||
|
||||
type wtFwpmSublayerFlags uint32
|
||||
|
||||
const (
|
||||
cFWPM_SUBLAYER_FLAG_PERSISTENT wtFwpmSublayerFlags = 0x00000001 // FWPM_SUBLAYER_FLAG_PERSISTENT defined in fwpmtypes.h
|
||||
)
|
||||
|
||||
// FWPM_SUBLAYER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_sublayer0_)
|
||||
type wtFwpmSublayer0 struct {
|
||||
subLayerKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags wtFwpmSublayerFlags
|
||||
providerKey *windows.GUID // Windows type: *GUID
|
||||
providerData wtFwpByteBlob
|
||||
weight uint16
|
||||
}
|
||||
|
||||
// Defined in rpcdce.h
|
||||
type wtRpcCAuthN uint32
|
||||
|
||||
const (
|
||||
cRPC_C_AUTHN_NONE wtRpcCAuthN = 0
|
||||
cRPC_C_AUTHN_WINNT wtRpcCAuthN = 10
|
||||
cRPC_C_AUTHN_DEFAULT wtRpcCAuthN = 0xFFFFFFFF
|
||||
)
|
||||
|
||||
// FWPM_PROVIDER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/sv-se/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_provider0).
|
||||
type wtFwpmProvider0 struct {
|
||||
providerKey windows.GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags uint32
|
||||
providerData wtFwpByteBlob
|
||||
serviceName *uint16
|
||||
}
|
||||
|
||||
type wtIPProto uint32
|
||||
|
||||
const (
|
||||
cIPPROTO_ICMP wtIPProto = 1
|
||||
cIPPROTO_ICMPV6 wtIPProto = 58
|
||||
cIPPROTO_TCP wtIPProto = 6
|
||||
cIPPROTO_UDP wtIPProto = 17
|
||||
)
|
||||
|
||||
const (
|
||||
cFWP_ACTRL_MATCH_FILTER = 1
|
||||
)
|
||||
92
client/internal/dns/dnsfw/types_windows_32.go
Normal file
92
client/internal/dns/dnsfw/types_windows_32.go
Normal file
@@ -0,0 +1,92 @@
|
||||
//go:build windows && (386 || arm)
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/types_windows_32.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
const (
|
||||
wtFwpByteBlob_Size = 8
|
||||
wtFwpByteBlob_data_Offset = 4
|
||||
|
||||
wtFwpConditionValue0_Size = 8
|
||||
wtFwpConditionValue0_uint8_Offset = 4
|
||||
|
||||
wtFwpmDisplayData0_Size = 8
|
||||
wtFwpmDisplayData0_description_Offset = 4
|
||||
|
||||
wtFwpmFilter0_Size = 152
|
||||
wtFwpmFilter0_displayData_Offset = 16
|
||||
wtFwpmFilter0_flags_Offset = 24
|
||||
wtFwpmFilter0_providerKey_Offset = 28
|
||||
wtFwpmFilter0_providerData_Offset = 32
|
||||
wtFwpmFilter0_layerKey_Offset = 40
|
||||
wtFwpmFilter0_subLayerKey_Offset = 56
|
||||
wtFwpmFilter0_weight_Offset = 72
|
||||
wtFwpmFilter0_numFilterConditions_Offset = 80
|
||||
wtFwpmFilter0_filterCondition_Offset = 84
|
||||
wtFwpmFilter0_action_Offset = 88
|
||||
wtFwpmFilter0_providerContextKey_Offset = 112
|
||||
wtFwpmFilter0_reserved_Offset = 128
|
||||
wtFwpmFilter0_filterID_Offset = 136
|
||||
wtFwpmFilter0_effectiveWeight_Offset = 144
|
||||
|
||||
wtFwpmFilterCondition0_Size = 28
|
||||
wtFwpmFilterCondition0_matchType_Offset = 16
|
||||
wtFwpmFilterCondition0_conditionValue_Offset = 20
|
||||
|
||||
wtFwpmSession0_Size = 48
|
||||
wtFwpmSession0_displayData_Offset = 16
|
||||
wtFwpmSession0_flags_Offset = 24
|
||||
wtFwpmSession0_txnWaitTimeoutInMSec_Offset = 28
|
||||
wtFwpmSession0_processId_Offset = 32
|
||||
wtFwpmSession0_sid_Offset = 36
|
||||
wtFwpmSession0_username_Offset = 40
|
||||
wtFwpmSession0_kernelMode_Offset = 44
|
||||
|
||||
wtFwpmSublayer0_Size = 44
|
||||
wtFwpmSublayer0_displayData_Offset = 16
|
||||
wtFwpmSublayer0_flags_Offset = 24
|
||||
wtFwpmSublayer0_providerKey_Offset = 28
|
||||
wtFwpmSublayer0_providerData_Offset = 32
|
||||
wtFwpmSublayer0_weight_Offset = 40
|
||||
|
||||
wtFwpProvider0_Size = 40
|
||||
wtFwpProvider0_displayData_Offset = 16
|
||||
wtFwpProvider0_flags_Offset = 24
|
||||
wtFwpProvider0_providerData_Offset = 28
|
||||
wtFwpProvider0_serviceName_Offset = 36
|
||||
|
||||
wtFwpTokenInformation_Size = 16
|
||||
|
||||
wtFwpValue0_Size = 8
|
||||
wtFwpValue0_value_Offset = 4
|
||||
)
|
||||
|
||||
// FWPM_FILTER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter0).
|
||||
type wtFwpmFilter0 struct {
|
||||
filterKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags wtFwpmFilterFlags
|
||||
providerKey *windows.GUID // Windows type: *GUID
|
||||
providerData wtFwpByteBlob
|
||||
layerKey windows.GUID // Windows type: GUID
|
||||
subLayerKey windows.GUID // Windows type: GUID
|
||||
weight wtFwpValue0
|
||||
numFilterConditions uint32
|
||||
filterCondition *wtFwpmFilterCondition0
|
||||
action wtFwpmAction0
|
||||
offset1 [4]byte // Layout correction field
|
||||
providerContextKey windows.GUID // Windows type: GUID
|
||||
reserved *windows.GUID // Windows type: *GUID
|
||||
offset2 [4]byte // Layout correction field
|
||||
filterID uint64
|
||||
effectiveWeight wtFwpValue0
|
||||
}
|
||||
89
client/internal/dns/dnsfw/types_windows_64.go
Normal file
89
client/internal/dns/dnsfw/types_windows_64.go
Normal file
@@ -0,0 +1,89 @@
|
||||
//go:build windows && (amd64 || arm64)
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/types_windows_64.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
const (
|
||||
wtFwpByteBlob_Size = 16
|
||||
wtFwpByteBlob_data_Offset = 8
|
||||
|
||||
wtFwpConditionValue0_Size = 16
|
||||
wtFwpConditionValue0_uint8_Offset = 8
|
||||
|
||||
wtFwpmDisplayData0_Size = 16
|
||||
wtFwpmDisplayData0_description_Offset = 8
|
||||
|
||||
wtFwpmFilter0_Size = 200
|
||||
wtFwpmFilter0_displayData_Offset = 16
|
||||
wtFwpmFilter0_flags_Offset = 32
|
||||
wtFwpmFilter0_providerKey_Offset = 40
|
||||
wtFwpmFilter0_providerData_Offset = 48
|
||||
wtFwpmFilter0_layerKey_Offset = 64
|
||||
wtFwpmFilter0_subLayerKey_Offset = 80
|
||||
wtFwpmFilter0_weight_Offset = 96
|
||||
wtFwpmFilter0_numFilterConditions_Offset = 112
|
||||
wtFwpmFilter0_filterCondition_Offset = 120
|
||||
wtFwpmFilter0_action_Offset = 128
|
||||
wtFwpmFilter0_providerContextKey_Offset = 152
|
||||
wtFwpmFilter0_reserved_Offset = 168
|
||||
wtFwpmFilter0_filterID_Offset = 176
|
||||
wtFwpmFilter0_effectiveWeight_Offset = 184
|
||||
|
||||
wtFwpmFilterCondition0_Size = 40
|
||||
wtFwpmFilterCondition0_matchType_Offset = 16
|
||||
wtFwpmFilterCondition0_conditionValue_Offset = 24
|
||||
|
||||
wtFwpmSession0_Size = 72
|
||||
wtFwpmSession0_displayData_Offset = 16
|
||||
wtFwpmSession0_flags_Offset = 32
|
||||
wtFwpmSession0_txnWaitTimeoutInMSec_Offset = 36
|
||||
wtFwpmSession0_processId_Offset = 40
|
||||
wtFwpmSession0_sid_Offset = 48
|
||||
wtFwpmSession0_username_Offset = 56
|
||||
wtFwpmSession0_kernelMode_Offset = 64
|
||||
|
||||
wtFwpmSublayer0_Size = 72
|
||||
wtFwpmSublayer0_displayData_Offset = 16
|
||||
wtFwpmSublayer0_flags_Offset = 32
|
||||
wtFwpmSublayer0_providerKey_Offset = 40
|
||||
wtFwpmSublayer0_providerData_Offset = 48
|
||||
wtFwpmSublayer0_weight_Offset = 64
|
||||
|
||||
wtFwpProvider0_Size = 64
|
||||
wtFwpProvider0_displayData_Offset = 16
|
||||
wtFwpProvider0_flags_Offset = 32
|
||||
wtFwpProvider0_providerData_Offset = 40
|
||||
wtFwpProvider0_serviceName_Offset = 56
|
||||
|
||||
wtFwpValue0_Size = 16
|
||||
wtFwpValue0_value_Offset = 8
|
||||
)
|
||||
|
||||
// FWPM_FILTER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter0).
|
||||
type wtFwpmFilter0 struct {
|
||||
filterKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags wtFwpmFilterFlags // Windows type: UINT32
|
||||
providerKey *windows.GUID // Windows type: *GUID
|
||||
providerData wtFwpByteBlob
|
||||
layerKey windows.GUID // Windows type: GUID
|
||||
subLayerKey windows.GUID // Windows type: GUID
|
||||
weight wtFwpValue0
|
||||
numFilterConditions uint32
|
||||
filterCondition *wtFwpmFilterCondition0
|
||||
action wtFwpmAction0
|
||||
offset1 [4]byte // Layout correction field
|
||||
providerContextKey windows.GUID // Windows type: GUID
|
||||
reserved *windows.GUID // Windows type: *GUID
|
||||
filterID uint64
|
||||
effectiveWeight wtFwpValue0
|
||||
}
|
||||
130
client/internal/dns/dnsfw/zsyscall_windows.go
Normal file
130
client/internal/dns/dnsfw/zsyscall_windows.go
Normal file
@@ -0,0 +1,130 @@
|
||||
// Code generated by 'go generate'; DO NOT EDIT.
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var _ unsafe.Pointer
|
||||
|
||||
// Do the interface allocations only once for common
|
||||
// Errno values.
|
||||
const (
|
||||
errnoERROR_IO_PENDING = 997
|
||||
)
|
||||
|
||||
var (
|
||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||
errERROR_EINVAL error = syscall.EINVAL
|
||||
)
|
||||
|
||||
// errnoErr returns common boxed Errno values, to prevent
|
||||
// allocations at runtime.
|
||||
func errnoErr(e syscall.Errno) error {
|
||||
switch e {
|
||||
case 0:
|
||||
return errERROR_EINVAL
|
||||
case errnoERROR_IO_PENDING:
|
||||
return errERROR_IO_PENDING
|
||||
}
|
||||
// TODO: add more here, after collecting data on the common
|
||||
// error values see on Windows. (perhaps when running
|
||||
// all.bat?)
|
||||
return e
|
||||
}
|
||||
|
||||
var (
|
||||
modfwpuclnt = windows.NewLazySystemDLL("fwpuclnt.dll")
|
||||
|
||||
procFwpmEngineClose0 = modfwpuclnt.NewProc("FwpmEngineClose0")
|
||||
procFwpmEngineOpen0 = modfwpuclnt.NewProc("FwpmEngineOpen0")
|
||||
procFwpmFilterAdd0 = modfwpuclnt.NewProc("FwpmFilterAdd0")
|
||||
procFwpmFreeMemory0 = modfwpuclnt.NewProc("FwpmFreeMemory0")
|
||||
procFwpmGetAppIdFromFileName0 = modfwpuclnt.NewProc("FwpmGetAppIdFromFileName0")
|
||||
procFwpmProviderAdd0 = modfwpuclnt.NewProc("FwpmProviderAdd0")
|
||||
procFwpmSubLayerAdd0 = modfwpuclnt.NewProc("FwpmSubLayerAdd0")
|
||||
procFwpmTransactionAbort0 = modfwpuclnt.NewProc("FwpmTransactionAbort0")
|
||||
procFwpmTransactionBegin0 = modfwpuclnt.NewProc("FwpmTransactionBegin0")
|
||||
procFwpmTransactionCommit0 = modfwpuclnt.NewProc("FwpmTransactionCommit0")
|
||||
)
|
||||
|
||||
func fwpmEngineClose0(engineHandle uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmEngineClose0.Addr(), 1, uintptr(engineHandle), 0, 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmEngineOpen0(serverName *uint16, authnService wtRpcCAuthN, authIdentity *uintptr, session *wtFwpmSession0, engineHandle unsafe.Pointer) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procFwpmEngineOpen0.Addr(), 5, uintptr(unsafe.Pointer(serverName)), uintptr(authnService), uintptr(unsafe.Pointer(authIdentity)), uintptr(unsafe.Pointer(session)), uintptr(engineHandle), 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmFilterAdd0(engineHandle uintptr, filter *wtFwpmFilter0, sd uintptr, id *uint64) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procFwpmFilterAdd0.Addr(), 4, uintptr(engineHandle), uintptr(unsafe.Pointer(filter)), uintptr(sd), uintptr(unsafe.Pointer(id)), 0, 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmFreeMemory0(p unsafe.Pointer) {
|
||||
syscall.Syscall(procFwpmFreeMemory0.Addr(), 1, uintptr(p), 0, 0)
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmGetAppIdFromFileName0(fileName *uint16, appID unsafe.Pointer) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmGetAppIdFromFileName0.Addr(), 2, uintptr(unsafe.Pointer(fileName)), uintptr(appID), 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmProviderAdd0(engineHandle uintptr, provider *wtFwpmProvider0, sd uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmProviderAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(provider)), uintptr(sd))
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmSubLayerAdd0(engineHandle uintptr, subLayer *wtFwpmSublayer0, sd uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmSubLayerAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(subLayer)), uintptr(sd))
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmTransactionAbort0(engineHandle uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmTransactionAbort0.Addr(), 1, uintptr(engineHandle), 0, 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmTransactionBegin0(engineHandle uintptr, flags uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmTransactionBegin0.Addr(), 2, uintptr(engineHandle), uintptr(flags), 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmTransactionCommit0(engineHandle uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmTransactionCommit0.Addr(), 1, uintptr(engineHandle), 0, 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/dnsfw"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/winregistry"
|
||||
)
|
||||
@@ -74,6 +75,7 @@ type registryConfigurator struct {
|
||||
routingAll bool
|
||||
gpo bool
|
||||
nrptEntryCount int
|
||||
dnsFirewall dnsfw.Manager
|
||||
origNameservers []netip.Addr
|
||||
}
|
||||
|
||||
@@ -94,8 +96,9 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||
}
|
||||
|
||||
configurator := ®istryConfigurator{
|
||||
guid: guid,
|
||||
gpo: useGPO,
|
||||
guid: guid,
|
||||
gpo: useGPO,
|
||||
dnsFirewall: dnsfw.New(),
|
||||
}
|
||||
|
||||
origNameservers, err := configurator.captureOriginalNameservers()
|
||||
@@ -276,16 +279,8 @@ func (r *registryConfigurator) disableWINSForInterface() error {
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||
if config.RouteAll {
|
||||
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
|
||||
return fmt.Errorf("add dns setup: %w", err)
|
||||
}
|
||||
} else if r.routingAll {
|
||||
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
|
||||
return fmt.Errorf("delete interface registry key property: %w", err)
|
||||
}
|
||||
r.routingAll = false
|
||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||
if err := r.applyRouteAll(config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.updateState(stateManager)
|
||||
@@ -327,6 +322,35 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) applyRouteAll(config HostDNSConfig) error {
|
||||
if config.RouteAll {
|
||||
if err := r.dnsFirewall.Enable(r.guid, config.ServerIP); err != nil {
|
||||
return fmt.Errorf("dns firewall: %w", err)
|
||||
}
|
||||
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
|
||||
merr := multierror.Append(nil, fmt.Errorf("add dns setup: %w", err))
|
||||
if dErr := r.dnsFirewall.Disable(); dErr != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("rollback dns firewall: %w", dErr))
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.dnsFirewall.Disable(); err != nil {
|
||||
log.Errorf("disable dns firewall: %v", err)
|
||||
}
|
||||
if !r.routingAll {
|
||||
return nil
|
||||
}
|
||||
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
|
||||
return fmt.Errorf("delete interface registry key property: %w", err)
|
||||
}
|
||||
r.routingAll = false
|
||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) {
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
Guid: r.guid,
|
||||
@@ -513,6 +537,10 @@ func (r *registryConfigurator) restoreHostDNS() error {
|
||||
return fmt.Errorf("remove interface registry key: %w", err)
|
||||
}
|
||||
|
||||
if err := r.dnsFirewall.Disable(); err != nil {
|
||||
log.Errorf("disable dns firewall: %v", err)
|
||||
}
|
||||
|
||||
go r.flushDNSCache()
|
||||
|
||||
return nil
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/dnsfw"
|
||||
)
|
||||
|
||||
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
|
||||
@@ -34,8 +36,9 @@ func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
||||
}()
|
||||
|
||||
cfg := ®istryConfigurator{
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
dnsFirewall: dnsfw.New(),
|
||||
}
|
||||
|
||||
// Create 125 domains which will result in 3 NRPT rules (50+50+25)
|
||||
@@ -134,8 +137,9 @@ func TestNRPTDomainBatching(t *testing.T) {
|
||||
}()
|
||||
|
||||
cfg := ®istryConfigurator{
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
dnsFirewall: dnsfw.New(),
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
|
||||
@@ -777,13 +777,24 @@ func (s *DefaultServer) applyHostConfig() {
|
||||
// context is released rather than leaked until GC.
|
||||
func (s *DefaultServer) registerFallback() {
|
||||
originalNameservers := s.hostManager.getOriginalNameservers()
|
||||
if len(originalNameservers) == 0 {
|
||||
|
||||
serverIP := s.service.RuntimeIP()
|
||||
var servers []netip.AddrPort
|
||||
for _, ns := range originalNameservers {
|
||||
if ns == serverIP {
|
||||
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, serverIP)
|
||||
continue
|
||||
}
|
||||
servers = append(servers, netip.AddrPortFrom(ns, DefaultPort))
|
||||
}
|
||||
|
||||
if len(servers) == 0 {
|
||||
log.Debugf("no fallback upstreams to register; clearing PriorityFallback handler")
|
||||
s.clearFallback()
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("registering original nameservers %v as upstream handlers with priority %d", originalNameservers, PriorityFallback)
|
||||
log.Infof("registering original nameservers %v as upstream handlers with priority %d", servers, PriorityFallback)
|
||||
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
@@ -797,11 +808,6 @@ func (s *DefaultServer) registerFallback() {
|
||||
return
|
||||
}
|
||||
handler.selectedRoutes = s.selectedRoutes
|
||||
|
||||
var servers []netip.AddrPort
|
||||
for _, ns := range originalNameservers {
|
||||
servers = append(servers, netip.AddrPortFrom(ns, DefaultPort))
|
||||
}
|
||||
handler.addRace(servers)
|
||||
|
||||
prev := s.fallbackHandler
|
||||
|
||||
@@ -700,6 +700,13 @@ func resolveURLsToIPs(urls []string) []net.IP {
|
||||
|
||||
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
|
||||
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
|
||||
// An explicit user "deselect all" must not be overridden by management auto-apply.
|
||||
// Auto-applying an exit node here would call SelectRoutes, which clears the
|
||||
// deselect-all flag and re-enables every route the user turned off.
|
||||
if m.routeSelector.IsDeselectAll() {
|
||||
return
|
||||
}
|
||||
|
||||
exitNodeInfo := m.collectExitNodeInfo(clientRoutes)
|
||||
if len(exitNodeInfo.allIDs) == 0 {
|
||||
return
|
||||
|
||||
71
client/internal/routemanager/selector_management_test.go
Normal file
71
client/internal/routemanager/selector_management_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func exitNodeRoutes(netID route.NetID, skipAutoApply bool) route.HAMap {
|
||||
haID := route.HAUniqueID(string(netID) + "|0.0.0.0/0")
|
||||
return route.HAMap{
|
||||
haID: []*route.Route{
|
||||
{
|
||||
ID: "r-" + route.ID(netID),
|
||||
NetID: netID,
|
||||
Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Enabled: true,
|
||||
SkipAutoApply: skipAutoApply,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement(t *testing.T) {
|
||||
t.Run("management auto-apply selects exit node without user selection", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
routes := exitNodeRoutes("exit1", false)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsSelected("exit1"), "auto-apply exit node should be selected")
|
||||
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "selected exit node should pass the filter")
|
||||
})
|
||||
|
||||
t.Run("management SkipAutoApply leaves exit node deselected", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
routes := exitNodeRoutes("exit1", true)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.False(t, m.routeSelector.IsSelected("exit1"), "SkipAutoApply exit node should not be selected")
|
||||
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "deselected exit node should be filtered out")
|
||||
})
|
||||
|
||||
t.Run("user selection is not overridden by management", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
require.NoError(t, m.routeSelector.SelectRoutes([]route.NetID{"exit1"}, true, []route.NetID{"exit1"}))
|
||||
routes := exitNodeRoutes("exit1", true)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsSelected("exit1"), "explicit user selection must survive a management sync that wants to skip auto-apply")
|
||||
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "user-selected exit node should pass the filter")
|
||||
})
|
||||
|
||||
t.Run("deselect-all is preserved across a management sync", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
m.routeSelector.DeselectAllRoutes()
|
||||
routes := exitNodeRoutes("exit1", false)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsDeselectAll(), "an explicit deselect-all must not be cleared by management auto-apply")
|
||||
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "no routes should be selected while deselect-all is set")
|
||||
})
|
||||
}
|
||||
@@ -116,6 +116,14 @@ func (rs *RouteSelector) DeselectAllRoutes() {
|
||||
clear(rs.selectedRoutes)
|
||||
}
|
||||
|
||||
// IsDeselectAll reports whether the user has explicitly deselected all routes.
|
||||
func (rs *RouteSelector) IsDeselectAll() bool {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
return rs.deselectAll
|
||||
}
|
||||
|
||||
// IsSelected checks if a specific route is selected.
|
||||
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||
rs.mu.RLock()
|
||||
|
||||
@@ -99,6 +99,9 @@ func addFields(entry *logrus.Entry) {
|
||||
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
|
||||
entry.Data[context.AccountIDKey] = ctxAccountID
|
||||
}
|
||||
if ctxUserAgent, ok := entry.Context.Value(context.UserAgentKey).(string); ok {
|
||||
entry.Data[context.UserAgentKey] = ctxUserAgent
|
||||
}
|
||||
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
|
||||
entry.Data[context.UserIDKey] = ctxInitiatorID
|
||||
}
|
||||
|
||||
@@ -19,6 +19,46 @@ readonly MSG_SEPARATOR="=========================================="
|
||||
# Utility Functions
|
||||
############################################
|
||||
|
||||
check_docker_sock_perms() {
|
||||
local sock="${DOCKER_HOST:-unix:///var/run/docker.sock}"
|
||||
sock="${sock#unix://}"
|
||||
|
||||
if [[ ! -S "$sock" ]]; then
|
||||
return 0
|
||||
fi
|
||||
|
||||
if [[ ! -r "$sock" ]] || [[ ! -w "$sock" ]]; then
|
||||
local group
|
||||
if [[ "${OSTYPE}" == "darwin"* ]]; then
|
||||
group="$(stat -f '%Sg' "$sock")"
|
||||
else
|
||||
group="$(stat -c '%G' "$sock")"
|
||||
fi
|
||||
|
||||
echo "Cannot access Docker socket: $sock" > /dev/stderr
|
||||
echo "" > /dev/stderr
|
||||
echo "Socket permissions:" > /dev/stderr
|
||||
ls -l "$sock" > /dev/stderr
|
||||
echo "" > /dev/stderr
|
||||
|
||||
if [[ "$group" == "docker" ]]; then
|
||||
echo "Your user may need to be added to the '$group' group:" > /dev/stderr
|
||||
echo " sudo usermod -aG $group \"$USER\"" > /dev/stderr
|
||||
echo "Then log out and back in, or run this for the current shell:" > /dev/stderr
|
||||
echo " newgrp $group" > /dev/stderr
|
||||
echo "Note: newgrp is temporary; usermod is the permanent group change." > /dev/stderr
|
||||
else
|
||||
echo "The Docker socket is owned by the '$group' group, which is not the standard 'docker' group." > /dev/stderr
|
||||
echo "For safety, this script will not suggest adding your user to '$group'." > /dev/stderr
|
||||
echo "Instead, either run this script with appropriate privileges (for example, via sudo) or follow Docker's post-install steps to configure access via the 'docker' group:" > /dev/stderr
|
||||
echo " https://docs.docker.com/engine/install/linux-postinstall/" > /dev/stderr
|
||||
fi
|
||||
|
||||
exit 1
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
check_docker_compose() {
|
||||
if command -v docker-compose &> /dev/null
|
||||
then
|
||||
@@ -581,12 +621,15 @@ start_services_and_show_instructions() {
|
||||
}
|
||||
|
||||
init_environment() {
|
||||
# Check if docker compose is installed using check_docker_compose function
|
||||
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
|
||||
check_docker_sock_perms
|
||||
|
||||
initialize_default_values
|
||||
configure_domain
|
||||
configure_reverse_proxy
|
||||
|
||||
check_jq
|
||||
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
|
||||
|
||||
check_existing_installation
|
||||
generate_configuration_files
|
||||
|
||||
@@ -666,8 +666,10 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
|
||||
case resp := <-conn.sendChan:
|
||||
if err := conn.sendResponse(resp); err != nil {
|
||||
errChan <- err
|
||||
log.WithContext(conn.ctx).Tracef("Failed to send response to proxy %s: %v", conn.proxyID, err)
|
||||
return
|
||||
}
|
||||
log.WithContext(conn.ctx).Tracef("Send response to proxy %s", conn.proxyID)
|
||||
case <-conn.ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ const (
|
||||
RoleKey = nbcontext.RoleKey
|
||||
UserIDKey = nbcontext.UserIDKey
|
||||
PeerIDKey = nbcontext.PeerIDKey
|
||||
UserAgentKey = nbcontext.UserAgentKey
|
||||
)
|
||||
|
||||
// RoleFromContext returns the role stored in ctx, or empty string and false if absent.
|
||||
|
||||
@@ -21,6 +21,8 @@ const (
|
||||
httpRequestCounterPrefix = "management.http.request.counter"
|
||||
httpResponseCounterPrefix = "management.http.response.counter"
|
||||
httpRequestDurationPrefix = "management.http.request.duration.ms"
|
||||
|
||||
RequestIDHeader = "X-Request-Id"
|
||||
)
|
||||
|
||||
// WrappedResponseWriter is a wrapper for http.ResponseWriter that allows the
|
||||
@@ -172,6 +174,10 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
|
||||
reqID := xid.New().String()
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.UserAgentKey, r.UserAgent())
|
||||
|
||||
rw.Header().Set(RequestIDHeader, reqID)
|
||||
|
||||
log.WithContext(ctx).Tracef("HTTP request %v: %v %v", reqID, r.Method, r.URL)
|
||||
|
||||
|
||||
@@ -557,7 +557,6 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute
|
||||
return enabledRoutes, disabledRoutes
|
||||
}
|
||||
|
||||
|
||||
func (c *NetworkMapComponents) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
|
||||
var filteredRoutes []*route.Route
|
||||
for _, r := range routes {
|
||||
@@ -628,9 +627,14 @@ func (c *NetworkMapComponents) getDefaultPermit(r *route.Route, includeIPv6 bool
|
||||
|
||||
rules := []*RouteFirewallRule{&rule}
|
||||
|
||||
if includeIPv6 && r.IsDynamic() {
|
||||
isDefaultV4 := r.Network.Addr().Is4() && r.Network.Bits() == 0
|
||||
if includeIPv6 && (r.IsDynamic() || isDefaultV4) {
|
||||
ruleV6 := rule
|
||||
ruleV6.SourceRanges = []string{"::/0"}
|
||||
if isDefaultV4 {
|
||||
ruleV6.Destination = "::/0"
|
||||
ruleV6.RouteID = r.ID + "-v6-default"
|
||||
}
|
||||
rules = append(rules, &ruleV6)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -1029,6 +1030,48 @@ func TestComponents_RouteDefaultPermit(t *testing.T) {
|
||||
assert.True(t, hasDefaultPermit, "route without ACG should have default permit rule with 0.0.0.0/0 source")
|
||||
}
|
||||
|
||||
// TestComponents_ExitNodeDefaultPermitIPv6 verifies that a default exit node route
|
||||
// (0.0.0.0/0) without AccessControlGroups also emits an IPv6 default permit rule
|
||||
// (::/0 source and destination) for peers that support IPv6, mirroring the route
|
||||
// the client installs. Without it, IPv6 traffic is routed to the exit node but
|
||||
// dropped at the forward chain.
|
||||
func TestComponents_ExitNodeDefaultPermitIPv6(t *testing.T) {
|
||||
account, validatedPeers := scalableTestAccount(20, 2)
|
||||
|
||||
routingPeerID := "peer-5"
|
||||
routingPeer := account.Peers[routingPeerID]
|
||||
routingPeer.IPv6 = netip.MustParseAddr("fd00::5")
|
||||
routingPeer.Meta.Capabilities = append(routingPeer.Meta.Capabilities, nbpeer.PeerCapabilityIPv6Overlay)
|
||||
|
||||
account.Routes["route-exit"] = &route.Route{
|
||||
ID: "route-exit", Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
PeerID: routingPeerID, Peer: routingPeer.Key,
|
||||
Enabled: true, Groups: []string{"group-all"}, PeerGroups: []string{"group-0"},
|
||||
AccessControlGroups: []string{},
|
||||
AccountID: "test-account",
|
||||
}
|
||||
|
||||
nm := componentsNetworkMap(account, routingPeerID, validatedPeers)
|
||||
require.NotNil(t, nm)
|
||||
|
||||
hasV4 := false
|
||||
hasV6 := false
|
||||
for _, rfr := range nm.RoutesFirewallRules {
|
||||
switch rfr.Destination {
|
||||
case "0.0.0.0/0":
|
||||
if slices.Contains(rfr.SourceRanges, "0.0.0.0/0") {
|
||||
hasV4 = true
|
||||
}
|
||||
case "::/0":
|
||||
if slices.Contains(rfr.SourceRanges, "::/0") {
|
||||
hasV6 = true
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.True(t, hasV4, "exit node route should have an IPv4 default permit rule (0.0.0.0/0)")
|
||||
assert.True(t, hasV6, "exit node route should have an IPv6 default permit rule (::/0)")
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
// 15. MULTIPLE ROUTERS PER NETWORK
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -249,6 +249,7 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
Private: private,
|
||||
MaxDialTimeout: maxDialTimeout,
|
||||
MaxSessionIdleTimeout: maxSessionIdleTimeout,
|
||||
MappingBatchWatchdog: envDurationOrDefault("NB_PROXY_MAPPING_BATCH_WATCHDOG", 0),
|
||||
GeoDataDir: geoDataDir,
|
||||
CrowdSecAPIURL: crowdsecAPIURL,
|
||||
CrowdSecAPIKey: crowdsecAPIKey,
|
||||
|
||||
@@ -28,6 +28,10 @@ import (
|
||||
|
||||
const deviceNamePrefix = "ingress-proxy-"
|
||||
|
||||
const clientStopTimeout = 30 * time.Second
|
||||
|
||||
const createProxyPeerTimeout = 30 * time.Second
|
||||
|
||||
// backendKey identifies a backend by its host:port from the target URL.
|
||||
type backendKey string
|
||||
|
||||
@@ -162,6 +166,7 @@ type NetBird struct {
|
||||
|
||||
clientsMux sync.RWMutex
|
||||
clients map[types.AccountID]*clientEntry
|
||||
lifecycleMu sync.Map
|
||||
initLogOnce sync.Once
|
||||
statusNotifier statusNotifier
|
||||
// readyHandler runs after the embedded client for an account reports
|
||||
@@ -177,6 +182,10 @@ type NetBird struct {
|
||||
// (i.e. when a new client was actually created, not when an existing one
|
||||
// was reused). The duration covers keygen + gRPC CreateProxyPeer + embed.New.
|
||||
OnAddPeer func(d time.Duration, err error)
|
||||
|
||||
// startClient runs the post-create client startup. Nil uses runClientStartup;
|
||||
// tests override it to avoid a real embed client.Start.
|
||||
startClient func(accountID types.AccountID, client *embed.Client)
|
||||
}
|
||||
|
||||
// ClientDebugInfo contains debug information about a client.
|
||||
@@ -200,31 +209,20 @@ type skipTLSVerifyContextKey struct{}
|
||||
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error {
|
||||
si := serviceInfo{serviceID: serviceID}
|
||||
|
||||
n.clientsMux.Lock()
|
||||
if n.registerExistingClient(accountID, key, si) {
|
||||
return nil
|
||||
}
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
entry.services[key] = si
|
||||
started := entry.started
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).Debug("registered service with existing client")
|
||||
|
||||
if started && n.statusNotifier != nil {
|
||||
// Use a background context, not the caller's: the management
|
||||
// connection notification must land even if the request /
|
||||
// stream that triggered this registration is cancelled.
|
||||
// Mirrors the async runClientStartup path.
|
||||
if err := n.statusNotifier.NotifyStatus(context.Background(), accountID, serviceID, true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).WithError(err).Warn("failed to notify status for existing client")
|
||||
}
|
||||
lifecycle := n.accountLifecycle(accountID)
|
||||
lifecycle.Lock()
|
||||
transferred := false
|
||||
defer func() {
|
||||
if !transferred {
|
||||
lifecycle.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
if n.registerExistingClient(accountID, key, si) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -234,10 +232,10 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
|
||||
n.OnAddPeer(time.Since(createStart), err)
|
||||
}
|
||||
if err != nil {
|
||||
n.clientsMux.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
n.clientsMux.Lock()
|
||||
n.clients[accountID] = entry
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
@@ -246,17 +244,64 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
|
||||
"service_key": key,
|
||||
}).Info("created new client for account")
|
||||
|
||||
// Attempt to start the client in the background; if this fails we will
|
||||
// retry on the first request via RoundTrip. runClientStartup uses its
|
||||
// own background context so the caller's request-scoped ctx can't
|
||||
// cancel the inbound bring-up.
|
||||
go n.runClientStartup(accountID, entry.client)
|
||||
transferred = true
|
||||
go func() {
|
||||
defer lifecycle.Unlock()
|
||||
n.startClientStartup(accountID, entry.client)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *NetBird) startClientStartup(accountID types.AccountID, client *embed.Client) {
|
||||
if n.startClient != nil {
|
||||
n.startClient(accountID, client)
|
||||
return
|
||||
}
|
||||
n.runClientStartup(accountID, client)
|
||||
}
|
||||
|
||||
// registerExistingClient registers the service against an already-present
|
||||
// client for the account and returns true when it did. It notifies management
|
||||
// of the new service when the client is already started.
|
||||
func (n *NetBird) registerExistingClient(accountID types.AccountID, key ServiceKey, si serviceInfo) bool {
|
||||
n.clientsMux.Lock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
n.clientsMux.Unlock()
|
||||
return false
|
||||
}
|
||||
entry.services[key] = si
|
||||
started := entry.started
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).Debug("registered service with existing client")
|
||||
|
||||
if started && n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(context.Background(), accountID, si.serviceID, true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).WithError(err).Warn("failed to notify status for existing client")
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// accountLifecycle returns the per-account lifecycle mutex, serialising client
|
||||
// creation against teardown so a slow client.Stop cannot race a new
|
||||
// client.Start for the same account, without blocking clientsMux.
|
||||
func (n *NetBird) accountLifecycle(accountID types.AccountID) *sync.Mutex {
|
||||
mu, _ := n.lifecycleMu.LoadOrStore(accountID, &sync.Mutex{})
|
||||
return mu.(*sync.Mutex)
|
||||
}
|
||||
|
||||
// createClientEntry generates a WireGuard keypair, authenticates with management,
|
||||
// and creates an embedded NetBird client. Must be called with clientsMux held.
|
||||
// and creates an embedded NetBird client. Must be called with the account's
|
||||
// lifecycle mutex held.
|
||||
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) {
|
||||
serviceID := si.serviceID
|
||||
n.logger.WithFields(log.Fields{
|
||||
@@ -276,7 +321,9 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
"public_key": publicKey.String(),
|
||||
}).Debug("authenticating new proxy peer with management")
|
||||
|
||||
resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
|
||||
createCtx, cancel := context.WithTimeout(ctx, createProxyPeerTimeout)
|
||||
defer cancel()
|
||||
resp, err := n.mgmtClient.CreateProxyPeer(createCtx, &proto.CreateProxyPeerRequest{
|
||||
ServiceId: string(serviceID),
|
||||
AccountId: string(accountID),
|
||||
Token: authToken,
|
||||
@@ -444,6 +491,15 @@ func (n *NetBird) notifyClientReady(accountID types.AccountID, client *embed.Cli
|
||||
// RemovePeer unregisters a service from an account. The client is only stopped
|
||||
// when no services are using it anymore.
|
||||
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key ServiceKey) error {
|
||||
lifecycle := n.accountLifecycle(accountID)
|
||||
lifecycle.Lock()
|
||||
transferred := false
|
||||
defer func() {
|
||||
if !transferred {
|
||||
lifecycle.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
n.clientsMux.Lock()
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
@@ -466,17 +522,8 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
|
||||
delete(entry.services, key)
|
||||
|
||||
stopClient := len(entry.services) == 0
|
||||
var client *embed.Client
|
||||
var transport, insecureTransport *http.Transport
|
||||
var inbound any
|
||||
var stopHandler func(types.AccountID, any)
|
||||
if stopClient {
|
||||
n.logger.WithField("account_id", accountID).Info("stopping client, no more services")
|
||||
client = entry.client
|
||||
transport = entry.transport
|
||||
insecureTransport = entry.insecureTransport
|
||||
inbound = entry.inbound
|
||||
stopHandler = n.stopHandler
|
||||
delete(n.clients, accountID)
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
@@ -490,19 +537,40 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
|
||||
n.notifyDisconnect(ctx, accountID, key, si.serviceID)
|
||||
|
||||
if stopClient {
|
||||
if inbound != nil && stopHandler != nil {
|
||||
stopHandler(accountID, inbound)
|
||||
}
|
||||
transport.CloseIdleConnections()
|
||||
insecureTransport.CloseIdleConnections()
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
|
||||
}
|
||||
transferred = true
|
||||
go n.stopClientLocked(accountID, lifecycle, entry)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// stopClientLocked releases a client's resources off the caller's goroutine so a
|
||||
// slow client.Stop cannot wedge the mapping receive loop (which calls RemovePeer
|
||||
// synchronously). It unlocks lifecycle when done so a new client.Start for the
|
||||
// same account waits for this teardown.
|
||||
func (n *NetBird) stopClientLocked(accountID types.AccountID, lifecycle *sync.Mutex, entry *clientEntry) {
|
||||
defer lifecycle.Unlock()
|
||||
|
||||
if entry.inbound != nil && n.stopHandler != nil {
|
||||
n.stopHandler(accountID, entry.inbound)
|
||||
}
|
||||
if entry.transport != nil {
|
||||
entry.transport.CloseIdleConnections()
|
||||
}
|
||||
if entry.insecureTransport != nil {
|
||||
entry.insecureTransport.CloseIdleConnections()
|
||||
}
|
||||
if entry.client == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), clientStopTimeout)
|
||||
defer cancel()
|
||||
if err := entry.client.Stop(ctx); err != nil {
|
||||
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NetBird) notifyDisconnect(ctx context.Context, accountID types.AccountID, key ServiceKey, serviceID types.ServiceID) {
|
||||
if n.statusNotifier == nil {
|
||||
return
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -22,6 +23,18 @@ func (m *mockMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxy
|
||||
return &proto.CreateProxyPeerResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
// signalMgmtClient closes entered the first time CreateProxyPeer is called, so
|
||||
// tests can detect AddPeer reaching client creation.
|
||||
type signalMgmtClient struct {
|
||||
entered chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (m *signalMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
|
||||
m.once.Do(func() { close(m.entered) })
|
||||
return &proto.CreateProxyPeerResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
type mockStatusNotifier struct {
|
||||
mu sync.Mutex
|
||||
statuses []statusCall
|
||||
@@ -52,11 +65,15 @@ func (m *mockStatusNotifier) calls() []statusCall {
|
||||
// mockNetBird creates a NetBird instance for testing without actually connecting.
|
||||
// It uses an invalid management URL to prevent real connections.
|
||||
func mockNetBird() *NetBird {
|
||||
return NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
MgmtAddr: "http://invalid.test:9999",
|
||||
WGPort: 0,
|
||||
PreSharedKey: "",
|
||||
}, nil, nil, &mockMgmtClient{})
|
||||
// Skip the real embed client.Start, which would hang against the unreachable
|
||||
// mgmt URL and (now that the lifecycle lock spans startup) serialise removes.
|
||||
nb.startClient = func(types.AccountID, *embed.Client) {}
|
||||
return nb
|
||||
}
|
||||
|
||||
func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
|
||||
@@ -288,6 +305,7 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
|
||||
WGPort: 0,
|
||||
PreSharedKey: "",
|
||||
}, nil, notifier, &mockMgmtClient{})
|
||||
nb.startClient = func(types.AccountID, *embed.Client) {}
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add first service — creates a new client entry.
|
||||
@@ -372,6 +390,117 @@ func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
|
||||
assert.False(t, calls[0].connected)
|
||||
}
|
||||
|
||||
// TestNetBird_RemovePeer_TeardownIsAsync proves the fix for the receive-loop
|
||||
// stall: RemovePeer must return promptly even when the client teardown blocks,
|
||||
// because teardown runs off the caller's goroutine. The receive loop calls
|
||||
// RemovePeer synchronously, so a blocking teardown inline would wedge it.
|
||||
func TestNetBird_RemovePeer_TeardownIsAsync(t *testing.T) {
|
||||
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
MgmtAddr: "http://invalid.test:9999",
|
||||
}, nil, &mockStatusNotifier{}, &mockMgmtClient{})
|
||||
|
||||
accountID := types.AccountID("acct-async-teardown")
|
||||
key := DomainServiceKey("svc.example")
|
||||
|
||||
teardownEntered := make(chan struct{})
|
||||
releaseTeardown := make(chan struct{})
|
||||
nb.SetClientLifecycle(nil, func(types.AccountID, any) {
|
||||
close(teardownEntered)
|
||||
<-releaseTeardown
|
||||
})
|
||||
|
||||
nb.clientsMux.Lock()
|
||||
nb.clients[accountID] = &clientEntry{
|
||||
services: map[ServiceKey]serviceInfo{key: {serviceID: types.ServiceID("svc-1")}},
|
||||
started: true,
|
||||
inbound: struct{}{},
|
||||
}
|
||||
nb.clientsMux.Unlock()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() { done <- nb.RemovePeer(context.Background(), accountID, key) }()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("RemovePeer did not return while teardown was blocked — teardown is not async")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-teardownEntered:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("teardown never ran")
|
||||
}
|
||||
|
||||
close(releaseTeardown)
|
||||
}
|
||||
|
||||
// TestNetBird_AddPeer_WaitsForTeardown proves the lifecycle lock serialises a
|
||||
// new client bringup behind an in-flight teardown for the same account, so a
|
||||
// slow client.Stop can never race a new client.Start for that account.
|
||||
//
|
||||
// It targets the handoff race specifically: AddPeer is launched immediately
|
||||
// after RemovePeer returns, WITHOUT waiting for the teardown goroutine to start.
|
||||
// This only passes if RemovePeer acquires the lifecycle lock synchronously
|
||||
// (before returning) and hands it to the teardown goroutine — if the goroutine
|
||||
// acquired the lock itself, AddPeer could win the lock in this window and start
|
||||
// a replacement client while the old teardown is still pending.
|
||||
func TestNetBird_AddPeer_WaitsForTeardown(t *testing.T) {
|
||||
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
MgmtAddr: "http://invalid.test:9999",
|
||||
}, nil, &mockStatusNotifier{}, &mockMgmtClient{})
|
||||
nb.startClient = func(types.AccountID, *embed.Client) {}
|
||||
|
||||
accountID := types.AccountID("acct-serialize")
|
||||
key := DomainServiceKey("svc.example")
|
||||
|
||||
addEntered := make(chan struct{})
|
||||
releaseTeardown := make(chan struct{})
|
||||
nb.SetClientLifecycle(nil, func(types.AccountID, any) {
|
||||
// Block teardown until released. If AddPeer ever reaches createClientEntry
|
||||
// (signalled via the mgmt client below) while we hold the lock, the lock
|
||||
// failed to serialise and the test fails before we release.
|
||||
<-releaseTeardown
|
||||
})
|
||||
|
||||
nb.clientsMux.Lock()
|
||||
nb.clients[accountID] = &clientEntry{
|
||||
services: map[ServiceKey]serviceInfo{key: {serviceID: types.ServiceID("svc-1")}},
|
||||
started: true,
|
||||
inbound: struct{}{},
|
||||
}
|
||||
nb.clientsMux.Unlock()
|
||||
|
||||
// createClientEntry calls CreateProxyPeer; closing addEntered there tells us
|
||||
// AddPeer got past the lifecycle lock and into client creation.
|
||||
nb.mgmtClient = &signalMgmtClient{entered: addEntered}
|
||||
|
||||
require.NoError(t, nb.RemovePeer(context.Background(), accountID, key))
|
||||
|
||||
// Launch AddPeer with NO synchronisation against the teardown goroutine.
|
||||
addReturned := make(chan struct{})
|
||||
go func() {
|
||||
_ = nb.AddPeer(context.Background(), accountID, DomainServiceKey("svc2.example"), "key-2", types.ServiceID("svc-2"))
|
||||
close(addReturned)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-addEntered:
|
||||
t.Fatal("AddPeer entered client creation while teardown held the lifecycle lock — handoff race not closed")
|
||||
case <-addReturned:
|
||||
t.Fatal("AddPeer completed while teardown held the lifecycle lock — not serialised")
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
}
|
||||
|
||||
close(releaseTeardown)
|
||||
select {
|
||||
case <-addReturned:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("AddPeer never completed after teardown released the lifecycle lock")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotifyClientReady_UsesBackgroundCtx pins the contract that the
|
||||
// post-Start hooks (readyHandler + statusNotifier.NotifyStatus) run on
|
||||
// a fresh context.Background() rather than inheriting the AddPeer
|
||||
|
||||
@@ -114,6 +114,10 @@ type Config struct {
|
||||
MaxDialTimeout time.Duration
|
||||
// MaxSessionIdleTimeout caps the per-service session idle timeout.
|
||||
MaxSessionIdleTimeout time.Duration
|
||||
// MappingBatchWatchdog bounds how long a single mapping batch may spend
|
||||
// being applied before the receive loop reconnects to resync. Zero falls
|
||||
// back to the internal default.
|
||||
MappingBatchWatchdog time.Duration
|
||||
|
||||
// GeoDataDir is the directory containing GeoLite2 MMDB files.
|
||||
GeoDataDir string
|
||||
@@ -164,6 +168,7 @@ func New(ctx context.Context, cfg Config) *Server {
|
||||
Private: cfg.Private,
|
||||
MaxDialTimeout: cfg.MaxDialTimeout,
|
||||
MaxSessionIdleTimeout: cfg.MaxSessionIdleTimeout,
|
||||
MappingBatchWatchdog: cfg.MappingBatchWatchdog,
|
||||
GeoDataDir: cfg.GeoDataDir,
|
||||
CrowdSecAPIURL: cfg.CrowdSecAPIURL,
|
||||
CrowdSecAPIKey: cfg.CrowdSecAPIKey,
|
||||
|
||||
282
proxy/mapping_stall_test.go
Normal file
282
proxy/mapping_stall_test.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// blockingMgmtClient implements roundtrip's managementClient interface.
|
||||
// CreateProxyPeer parks until release is closed, signalling entry on entered.
|
||||
// This reproduces the confirmed real-world stall: createClientEntry calls
|
||||
// CreateProxyPeer synchronously while holding clientsMux, and the proxy's
|
||||
// receive loop calls that path synchronously inside processMappings.
|
||||
type blockingMgmtClient struct {
|
||||
entered chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (b *blockingMgmtClient) CreateProxyPeer(ctx context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
|
||||
b.once.Do(func() { close(b.entered) })
|
||||
// Park until the caller's context is cancelled. In production this ctx is
|
||||
// the gRPC mapping-stream context with no per-call timeout, so a slow or
|
||||
// unresponsive CreateProxyPeer parks the receive loop here indefinitely.
|
||||
<-ctx.Done()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
// gatedMappingStream is a mock GetMappingUpdate client stream that hands out a
|
||||
// pre-seeded list of messages, then records how many times Recv advanced. It
|
||||
// lets the test observe whether the single-threaded receive loop ever gets
|
||||
// past the first (blocking) batch to pull the second message.
|
||||
type gatedMappingStream struct {
|
||||
grpc.ClientStream
|
||||
messages []*proto.GetMappingUpdateResponse
|
||||
idx int32
|
||||
}
|
||||
|
||||
func (g *gatedMappingStream) Recv() (*proto.GetMappingUpdateResponse, error) {
|
||||
i := int(atomic.LoadInt32(&g.idx))
|
||||
if i >= len(g.messages) {
|
||||
// Block instead of returning EOF so the loop doesn't exit; we only
|
||||
// care whether the loop ever reaches this second Recv at all.
|
||||
select {}
|
||||
}
|
||||
msg := g.messages[i]
|
||||
atomic.AddInt32(&g.idx, 1)
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (g *gatedMappingStream) deliveredCount() int32 { return atomic.LoadInt32(&g.idx) }
|
||||
|
||||
func (g *gatedMappingStream) Header() (metadata.MD, error) { return nil, nil } //nolint:nilnil
|
||||
func (g *gatedMappingStream) Trailer() metadata.MD { return nil }
|
||||
func (g *gatedMappingStream) CloseSend() error { return nil }
|
||||
func (g *gatedMappingStream) Context() context.Context { return context.Background() }
|
||||
func (g *gatedMappingStream) SendMsg(any) error { return nil }
|
||||
func (g *gatedMappingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
// noopNotifier satisfies roundtrip's statusNotifier interface.
|
||||
type noopNotifier struct{}
|
||||
|
||||
func (noopNotifier) NotifyStatus(context.Context, types.AccountID, types.ServiceID, bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// noopProxyClient is a proto.ProxyServiceClient that no-ops the one method the
|
||||
// teardown unwind reaches (SendStatusUpdate, via notifyError when the parked
|
||||
// AddPeer is cancelled). The embedded nil interface satisfies the rest at
|
||||
// compile time; none of those methods are called by this test.
|
||||
type noopProxyClient struct {
|
||||
proto.ProxyServiceClient
|
||||
}
|
||||
|
||||
func (noopProxyClient) SendStatusUpdate(context.Context, *proto.SendStatusUpdateRequest, ...grpc.CallOption) (*proto.SendStatusUpdateResponse, error) {
|
||||
return &proto.SendStatusUpdateResponse{}, nil
|
||||
}
|
||||
|
||||
// TestMappingStream_StallsWhenApplyBlocks proves the deadlock: the proxy's
|
||||
// mapping receive loop processes batches strictly serially, so when applying
|
||||
// one batch blocks (here: createClientEntry parked on a synchronous
|
||||
// CreateProxyPeer call, exactly as observed in production), the loop never
|
||||
// advances to Recv the next batch. Management can keep sending updates onto
|
||||
// the stream with no error and no channel overflow, yet the proxy applies
|
||||
// nothing further — it is stuck.
|
||||
func TestMappingStream_StallsWhenApplyBlocks(t *testing.T) {
|
||||
logger := log.New()
|
||||
logger.SetLevel(log.PanicLevel)
|
||||
|
||||
mgmt := &blockingMgmtClient{
|
||||
entered: make(chan struct{}),
|
||||
}
|
||||
|
||||
nb := roundtrip.NewNetBird(
|
||||
context.Background(),
|
||||
"proxy-test",
|
||||
"proxy.example.com",
|
||||
roundtrip.ClientConfig{},
|
||||
logger,
|
||||
noopNotifier{},
|
||||
mgmt,
|
||||
)
|
||||
|
||||
s := &Server{
|
||||
Logger: logger,
|
||||
netbird: nb,
|
||||
mgmtClient: noopProxyClient{},
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
// First batch: a CREATED mapping for a brand-new account. addMapping ->
|
||||
// netbird.AddPeer -> createClientEntry -> CreateProxyPeer, which blocks.
|
||||
// Empty Path keeps setupHTTPMapping a no-op (it returns early), so the
|
||||
// ONLY blocking point is the synchronous CreateProxyPeer in AddPeer —
|
||||
// no routers/auth need wiring. The second batch exists only to detect
|
||||
// whether the loop ever advances past the blocked first batch.
|
||||
stream := &gatedMappingStream{
|
||||
messages: []*proto.GetMappingUpdateResponse{
|
||||
{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "svc-1",
|
||||
AccountId: "acct-1",
|
||||
AuthToken: "token-1",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "svc-2",
|
||||
AccountId: "acct-2",
|
||||
AuthToken: "token-2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// Unblock the parked apply on teardown via ctx (CreateProxyPeer returns
|
||||
// ctx.Err()), so the wedged loop goroutine unwinds before embed.New —
|
||||
// avoiding any dependency on collaborators this test deliberately leaves
|
||||
// nil. The deadlock is fully proven before this fires.
|
||||
t.Cleanup(cancel)
|
||||
|
||||
loopDone := make(chan struct{})
|
||||
syncDone := false
|
||||
go func() {
|
||||
defer close(loopDone)
|
||||
_ = s.handleMappingStream(ctx, stream, &syncDone, time.Time{})
|
||||
}()
|
||||
|
||||
// The loop must reach the blocking apply for the first batch.
|
||||
select {
|
||||
case <-mgmt.entered:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("receive loop never reached CreateProxyPeer for the first batch")
|
||||
}
|
||||
|
||||
// THE DEADLOCK: while the first batch is parked in CreateProxyPeer, the
|
||||
// single-threaded loop cannot advance. The second batch is never pulled,
|
||||
// even though it is already available on the stream. Give it ample time.
|
||||
// deliveredCount is atomic; syncDone is intentionally not read here because
|
||||
// the loop goroutine owns it (reading it from the test would race).
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
assert.Equal(t, int32(1), stream.deliveredCount(),
|
||||
"loop must NOT consume the second batch while the first is blocked in apply — proxy is stuck")
|
||||
|
||||
select {
|
||||
case <-loopDone:
|
||||
t.Fatal("receive loop returned while it should be wedged in apply")
|
||||
default:
|
||||
// Still wedged, as expected.
|
||||
}
|
||||
}
|
||||
|
||||
// TestMappingStream_StallsWhenRemoveBlocks proves the deadlock for the REMOVE
|
||||
// path observed in production: a mapping remove tears down the account's last
|
||||
// embedded client via netbird.RemovePeer -> client.Stop -> Engine.Stop, whose
|
||||
// jobExecutorWG.Wait() is unbounded. Because the receive loop is single-
|
||||
// threaded, a blocked remove wedges the loop: no further mapping updates of any
|
||||
// kind (create/modify/remove) are applied, while management keeps sending them
|
||||
// successfully (no send error, no channel-full). Matches the reported symptom:
|
||||
// the last log line is a remove that stops a client, then silence.
|
||||
func TestMappingStream_StallsWhenRemoveBlocks(t *testing.T) {
|
||||
logger := log.New()
|
||||
logger.SetLevel(log.PanicLevel)
|
||||
|
||||
enteredRemove := make(chan struct{})
|
||||
blockRemove := make(chan struct{})
|
||||
var once sync.Once
|
||||
|
||||
s := &Server{
|
||||
Logger: logger,
|
||||
mgmtClient: noopProxyClient{},
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
// Stand in for netbird.RemovePeer -> client.Stop hanging on
|
||||
// Engine.Stop's unbounded jobExecutorWG.Wait(). Only the first remove
|
||||
// blocks; later removes return immediately so the recovery assertion
|
||||
// can observe the loop advancing.
|
||||
removePeer: func(ctx context.Context, _ types.AccountID, _ roundtrip.ServiceKey) error {
|
||||
first := false
|
||||
once.Do(func() {
|
||||
first = true
|
||||
close(enteredRemove)
|
||||
})
|
||||
if !first {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-blockRemove:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Batch 1 removes a service (blocks in teardown). Batch 2 is a later update
|
||||
// that must never be applied while the remove is wedged.
|
||||
stream := &gatedMappingStream{
|
||||
messages: []*proto.GetMappingUpdateResponse{
|
||||
{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, Id: "svc-1", AccountId: "acct-1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, Id: "svc-2", AccountId: "acct-1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
loopDone := make(chan struct{})
|
||||
syncDone := false
|
||||
go func() {
|
||||
defer close(loopDone)
|
||||
_ = s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{})
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-enteredRemove:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("receive loop never reached the blocking remove for the first batch")
|
||||
}
|
||||
|
||||
// THE DEADLOCK: the loop is parked in the blocked remove and cannot advance.
|
||||
// syncDone is owned by the loop goroutine, so it is not read here.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
assert.Equal(t, int32(1), stream.deliveredCount(),
|
||||
"loop must NOT consume the second batch while the first remove is blocked — proxy is stuck")
|
||||
|
||||
select {
|
||||
case <-loopDone:
|
||||
t.Fatal("receive loop returned while it should be wedged on the remove")
|
||||
default:
|
||||
}
|
||||
|
||||
// Unblock and confirm the wedge was solely the blocked remove: the loop
|
||||
// then advances and consumes the next batch.
|
||||
close(blockRemove)
|
||||
assert.Eventually(t, func() bool {
|
||||
return stream.deliveredCount() >= 2
|
||||
}, 2*time.Second, 5*time.Millisecond,
|
||||
"once the remove unblocks, the loop must advance and consume the next batch")
|
||||
}
|
||||
@@ -118,6 +118,9 @@ type Server struct {
|
||||
// The mapping worker waits on this before processing updates.
|
||||
routerReady chan struct{}
|
||||
|
||||
// removePeer defaults to netbird.RemovePeer; overridable in tests.
|
||||
removePeer func(ctx context.Context, accountID types.AccountID, key roundtrip.ServiceKey) error
|
||||
|
||||
// inbound, when non-nil, manages per-account inbound listeners. Set by
|
||||
// initPrivateInbound only when Private is true so the standalone
|
||||
// proxy keeps its zero-overhead default path.
|
||||
@@ -227,6 +230,10 @@ type Server struct {
|
||||
// Zero means no cap (the proxy honors whatever management sends).
|
||||
// Set via NB_PROXY_MAX_SESSION_IDLE_TIMEOUT for shared deployments.
|
||||
MaxSessionIdleTimeout time.Duration
|
||||
// MappingBatchWatchdog bounds how long a single mapping batch may spend
|
||||
// in processMappings before the receive loop reconnects to resync.
|
||||
// Zero uses defaultMappingBatchWatchdog.
|
||||
MappingBatchWatchdog time.Duration
|
||||
}
|
||||
|
||||
// clampIdleTimeout returns d capped to MaxSessionIdleTimeout when configured.
|
||||
@@ -1172,24 +1179,30 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
|
||||
s.healthChecker.SetManagementConnected(false)
|
||||
}
|
||||
|
||||
connected := false
|
||||
onConnected := func() { connected = true }
|
||||
|
||||
var streamErr error
|
||||
if syncSupported {
|
||||
streamErr = s.trySyncMappings(ctx, client, &initialSyncDone)
|
||||
streamErr = s.trySyncMappings(ctx, client, &initialSyncDone, onConnected)
|
||||
if isSyncUnimplemented(streamErr) {
|
||||
syncSupported = false
|
||||
s.Logger.Info("management does not support SyncMappings, falling back to GetMappingUpdate")
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone, onConnected)
|
||||
}
|
||||
} else {
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone, onConnected)
|
||||
}
|
||||
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(false)
|
||||
}
|
||||
|
||||
// Stream established — reset backoff so the next failure retries quickly.
|
||||
bo.Reset()
|
||||
// Reset backoff only when a stream actually connected, so immediate
|
||||
// connect failures still back off instead of spinning.
|
||||
if connected {
|
||||
bo.Reset()
|
||||
}
|
||||
|
||||
if streamErr == nil {
|
||||
return fmt.Errorf("stream closed by server")
|
||||
@@ -1221,7 +1234,7 @@ func (s *Server) proxyCapabilities() *proto.ProxyCapabilities {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
|
||||
func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool, onConnected func()) error {
|
||||
connectTime := time.Now()
|
||||
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: s.ID,
|
||||
@@ -1234,6 +1247,7 @@ func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServ
|
||||
return fmt.Errorf("create mapping stream: %w", err)
|
||||
}
|
||||
|
||||
onConnected()
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(true)
|
||||
}
|
||||
@@ -1242,7 +1256,7 @@ func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServ
|
||||
return s.handleMappingStream(ctx, mappingClient, initialSyncDone, connectTime)
|
||||
}
|
||||
|
||||
func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
|
||||
func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool, onConnected func()) error {
|
||||
connectTime := time.Now()
|
||||
stream, err := client.SyncMappings(ctx)
|
||||
if err != nil {
|
||||
@@ -1263,6 +1277,7 @@ func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceC
|
||||
return fmt.Errorf("send sync init: %w", err)
|
||||
}
|
||||
|
||||
onConnected()
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(true)
|
||||
}
|
||||
@@ -1307,7 +1322,9 @@ func (s *Server) handleSyncMappingsStream(ctx context.Context, stream proto.Prox
|
||||
|
||||
batchStart := time.Now()
|
||||
s.Logger.Debug("Received mapping update, starting processing")
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
if err := s.processMappingsGuarded(ctx, msg.GetMapping()); err != nil {
|
||||
return err
|
||||
}
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
|
||||
|
||||
@@ -1391,7 +1408,9 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
||||
|
||||
batchStart := time.Now()
|
||||
s.Logger.Debug("Received mapping update, starting processing")
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
if err := s.processMappingsGuarded(ctx, msg.GetMapping()); err != nil {
|
||||
return err
|
||||
}
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
|
||||
}
|
||||
@@ -1456,6 +1475,44 @@ func redactMappingForLog(m *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
return c
|
||||
}
|
||||
|
||||
const defaultMappingBatchWatchdog = 2 * time.Minute
|
||||
|
||||
// mappingBatchWatchdog returns the configured batch watchdog or the default.
|
||||
func (s *Server) mappingBatchWatchdog() time.Duration {
|
||||
if s.MappingBatchWatchdog > 0 {
|
||||
return s.MappingBatchWatchdog
|
||||
}
|
||||
return defaultMappingBatchWatchdog
|
||||
}
|
||||
|
||||
// processMappingsGuarded applies a batch under a watchdog, returning an error
|
||||
// if processing exceeds the watchdog so the caller reconnects and resyncs
|
||||
// instead of wedging silently.
|
||||
func (s *Server) processMappingsGuarded(ctx context.Context, mappings []*proto.ProxyMapping) error {
|
||||
batchCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
s.processMappings(batchCtx, mappings)
|
||||
}()
|
||||
|
||||
watchdog := s.mappingBatchWatchdog()
|
||||
timer := time.NewTimer(watchdog)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
s.Logger.Errorf("processing mapping batch exceeded %s, cancelling and reconnecting to resync", watchdog)
|
||||
return fmt.Errorf("mapping batch processing stalled after %s", watchdog)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
|
||||
debug := s.Logger != nil && s.Logger.IsLevelEnabled(log.DebugLevel)
|
||||
for _, mapping := range mappings {
|
||||
@@ -1951,7 +2008,11 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
|
||||
func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) {
|
||||
accountID := types.AccountID(mapping.GetAccountId())
|
||||
svcKey := s.serviceKeyForMapping(mapping)
|
||||
if err := s.netbird.RemovePeer(ctx, accountID, svcKey); err != nil {
|
||||
removePeer := s.removePeer
|
||||
if removePeer == nil {
|
||||
removePeer = s.netbird.RemovePeer
|
||||
}
|
||||
if err := removePeer(ctx, accountID, svcKey); err != nil {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_id": mapping.GetId(),
|
||||
|
||||
@@ -417,15 +417,30 @@ if type uname >/dev/null 2>&1; then
|
||||
# Check the availability of a compatible package manager
|
||||
if check_use_bin_variable; then
|
||||
PACKAGE_MANAGER="bin"
|
||||
elif [ -e /run/ostree-booted ]; then
|
||||
if [ -x "$(command -v rpm-ostree)" ]; then
|
||||
PACKAGE_MANAGER="rpm-ostree"
|
||||
echo "The installation will be performed using rpm-ostree package manager"
|
||||
elif [ -x "$(command -v bootc)" ]; then
|
||||
echo "Detected bootc system without rpm-ostree." >&2
|
||||
echo "NetBird cannot be installed via package manager on this system." >&2
|
||||
echo "Options:" >&2
|
||||
echo " 1. Install via Distrobox (instructions in the installation docs)" >&2
|
||||
echo " 2. Rebuild your base image with rpm-ostree included" >&2
|
||||
echo " 3. Bake NetBird into your Containerfile" >&2
|
||||
exit 1
|
||||
else
|
||||
echo "Detected ostree-booted system without rpm-ostree or bootc." >&2
|
||||
echo "NetBird cannot be installed automatically on this atomic system." >&2
|
||||
echo "Please install NetBird by rebuilding your base image or use a supported package manager." >&2
|
||||
exit 1
|
||||
fi
|
||||
elif [ -x "$(command -v apt-get)" ]; then
|
||||
PACKAGE_MANAGER="apt"
|
||||
echo "The installation will be performed using apt package manager"
|
||||
elif [ -x "$(command -v dnf)" ]; then
|
||||
PACKAGE_MANAGER="dnf"
|
||||
echo "The installation will be performed using dnf package manager"
|
||||
elif [ -x "$(command -v rpm-ostree)" ]; then
|
||||
PACKAGE_MANAGER="rpm-ostree"
|
||||
echo "The installation will be performed using rpm-ostree package manager"
|
||||
elif [ -x "$(command -v yum)" ]; then
|
||||
PACKAGE_MANAGER="yum"
|
||||
echo "The installation will be performed using yum package manager"
|
||||
|
||||
@@ -6,4 +6,5 @@ const (
|
||||
RoleKey = "role"
|
||||
UserIDKey = "userID"
|
||||
PeerIDKey = "peerID"
|
||||
UserAgentKey = "userAgent"
|
||||
)
|
||||
|
||||
@@ -5107,31 +5107,63 @@ components:
|
||||
responses:
|
||||
not_found:
|
||||
description: Resource not found
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
validation_failed_simple:
|
||||
description: Validation failed
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
bad_request:
|
||||
description: Bad Request
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
internal_error:
|
||||
description: Internal Server Error
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
validation_failed:
|
||||
description: Validation failed
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
forbidden:
|
||||
description: Forbidden
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
requires_authentication:
|
||||
description: Requires authentication
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
conflict:
|
||||
description: Conflict
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
headers:
|
||||
X-Request-Id:
|
||||
description: |
|
||||
Unique identifier assigned to the request by the server and set on every
|
||||
response. Useful for correlating client requests with server-side logs.
|
||||
schema:
|
||||
type: string
|
||||
example: cot7r4n3l3vh3qj4qveg
|
||||
securitySchemes:
|
||||
BearerAuth:
|
||||
type: http
|
||||
|
||||
@@ -9,12 +9,14 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer"
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
"github.com/netbirdio/netbird/shared/relay/healthcheck"
|
||||
"github.com/netbirdio/netbird/shared/relay/messages"
|
||||
)
|
||||
@@ -172,6 +174,19 @@ type Client struct {
|
||||
stateSubscription *PeersStateSubscription
|
||||
|
||||
mtu uint16
|
||||
|
||||
// transportFallback, when set, records datagram-too-large failures so a
|
||||
// datagram-sized transport is avoided on subsequent connects. Shared via
|
||||
// the manager.
|
||||
transportFallback *transportFallback
|
||||
// datagramFallbackTriggered guards a single fallback per connection so a
|
||||
// burst of oversized datagrams triggers one reconnect, not many.
|
||||
datagramFallbackTriggered atomic.Bool
|
||||
}
|
||||
|
||||
// SetTransportFallback wires the shared datagram-transport fallback tracker.
|
||||
func (c *Client) SetTransportFallback(tf *transportFallback) {
|
||||
c.transportFallback = tf
|
||||
}
|
||||
|
||||
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
||||
@@ -361,12 +376,13 @@ func (c *Client) Close() error {
|
||||
}
|
||||
|
||||
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
dialers := c.getDialers()
|
||||
mode := transportModeFromEnv()
|
||||
dialers := c.getDialers(mode)
|
||||
|
||||
var conn net.Conn
|
||||
if c.serverIP.IsValid() {
|
||||
var err error
|
||||
conn, err = c.dialRaceDirect(ctx, dialers)
|
||||
conn, err = c.dialRaceDirect(ctx, mode, dialers)
|
||||
if err != nil {
|
||||
c.log.Infof("dial via server IP %s failed, falling back to FQDN: %v", c.serverIP, err)
|
||||
conn = nil
|
||||
@@ -375,6 +391,9 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
|
||||
if conn == nil {
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
|
||||
if mode.sequential() {
|
||||
rd.WithSequential()
|
||||
}
|
||||
var err error
|
||||
conn, err = rd.Dial(ctx)
|
||||
if err != nil {
|
||||
@@ -382,6 +401,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
}
|
||||
}
|
||||
c.relayConn = conn
|
||||
c.datagramFallbackTriggered.Store(false)
|
||||
|
||||
instanceURL, err := c.handShake(ctx)
|
||||
if err != nil {
|
||||
@@ -396,7 +416,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
}
|
||||
|
||||
// dialRaceDirect dials c.serverIP, preserving the original FQDN as the TLS ServerName for SNI.
|
||||
func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (net.Conn, error) {
|
||||
func (c *Client) dialRaceDirect(ctx context.Context, mode TransportMode, dialers []dialer.DialeFn) (net.Conn, error) {
|
||||
directURL, serverName, err := substituteHost(c.connectionURL, c.serverIP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("substitute host: %w", err)
|
||||
@@ -406,6 +426,9 @@ func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (
|
||||
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, directURL, dialers...).
|
||||
WithServerName(serverName)
|
||||
if mode.sequential() {
|
||||
rd.WithSequential()
|
||||
}
|
||||
return rd.Dial(ctx)
|
||||
}
|
||||
|
||||
@@ -631,13 +654,53 @@ func (c *Client) writeTo(containerRef *connContainer, dstID messages.PeerID, pay
|
||||
}
|
||||
|
||||
// the write always return with 0 length because the underling does not support the size feedback.
|
||||
_, err = c.relayConn.Write(msg)
|
||||
conn := c.relayConn
|
||||
_, err = conn.Write(msg)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to write transport message: %s", err)
|
||||
if errors.Is(err, netErr.ErrDatagramTooLarge) {
|
||||
c.onDatagramTooLarge(conn, err)
|
||||
} else {
|
||||
c.log.Errorf("failed to write transport message: %s", err)
|
||||
}
|
||||
}
|
||||
return len(payload), err
|
||||
}
|
||||
|
||||
// onDatagramTooLarge reacts to a datagram rejected as too large for the path.
|
||||
// When a non-datagram transport is available, it records a fallback for this
|
||||
// server and closes the connection so the reconnect avoids datagram-sized
|
||||
// transports. A single fallback is triggered per connection regardless of how
|
||||
// many oversized datagrams arrive. cause carries the datagram size and budget.
|
||||
func (c *Client) onDatagramTooLarge(conn net.Conn, cause error) {
|
||||
// Handle one oversized datagram per connection; a burst triggers a single
|
||||
// fallback (and a single log line), not many.
|
||||
if !c.datagramFallbackTriggered.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
|
||||
// If the selected mode offers no non-datagram transport (e.g. pinned to a
|
||||
// datagram-sized transport), reconnecting would just re-fail, so leave the
|
||||
// connection up rather than loop.
|
||||
if len(nonDatagramSized(c.baseDialers(transportModeFromEnv()))) == 0 {
|
||||
c.log.Warnf("%s, but no non-datagram transport is available, not falling back", cause)
|
||||
return
|
||||
}
|
||||
|
||||
// Without the shared tracker a reconnect would just select the same
|
||||
// transport again and re-fail, so leave the connection up rather than loop.
|
||||
if c.transportFallback == nil {
|
||||
c.log.Debugf("%s, but no transport fallback configured, leaving connection up", cause)
|
||||
return
|
||||
}
|
||||
|
||||
window := c.transportFallback.recordFailure(c.connectionURL)
|
||||
c.log.Warnf("%s, avoiding datagram-sized transport for %s", cause, window)
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
c.log.Debugf("close relay connection for transport fallback: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) listenForStopEvents(ctx context.Context, hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
|
||||
for {
|
||||
select {
|
||||
|
||||
18
shared/relay/client/dialer/capability.go
Normal file
18
shared/relay/client/dialer/capability.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package dialer
|
||||
|
||||
// DatagramSized is implemented by dialers whose connections carry each write in
|
||||
// a single datagram, so a write can be rejected when it exceeds the path's
|
||||
// datagram budget (e.g. QUIC). Transports without this capability (e.g.
|
||||
// WebSocket over TCP) impose no per-write size limit, so the relay client can
|
||||
// fall back to them when a datagram-sized transport rejects a write as too
|
||||
// large. The capability is advertised per dialer rather than hardcoded, so a
|
||||
// new transport only needs to declare whether it is datagram-sized.
|
||||
type DatagramSized interface {
|
||||
DatagramSized()
|
||||
}
|
||||
|
||||
// IsDatagramSized reports whether d produces datagram-sized connections.
|
||||
func IsDatagramSized(d DialeFn) bool {
|
||||
_, ok := d.(DatagramSized)
|
||||
return ok
|
||||
}
|
||||
@@ -4,4 +4,9 @@ import "errors"
|
||||
|
||||
var (
|
||||
ErrClosedByServer = errors.New("closed by server")
|
||||
|
||||
// ErrDatagramTooLarge is returned when a transport message exceeds the
|
||||
// QUIC datagram size the path to the relay can carry. The relay client
|
||||
// treats it as a signal to fall back to a non-datagram transport.
|
||||
ErrDatagramTooLarge = errors.New("datagram frame too large")
|
||||
)
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
)
|
||||
@@ -52,11 +51,8 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
err := c.session.SendDatagram(b)
|
||||
if err != nil {
|
||||
err = c.remoteCloseErrHandling(err)
|
||||
log.Errorf("failed to write to QUIC stream: %v", err)
|
||||
return 0, err
|
||||
if err := c.session.SendDatagram(b); err != nil {
|
||||
return 0, c.writeErrHandling(err, len(b))
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
@@ -95,3 +91,15 @@ func (c *Conn) remoteCloseErrHandling(err error) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// writeErrHandling normalizes SendDatagram errors. A datagram that exceeds the
|
||||
// path's QUIC packet budget is mapped to ErrDatagramTooLarge (annotated with the
|
||||
// datagram size and path budget) so the relay client can fall back to a
|
||||
// non-datagram transport.
|
||||
func (c *Conn) writeErrHandling(err error, size int) error {
|
||||
var tooLarge *quic.DatagramTooLargeError
|
||||
if errors.As(err, &tooLarge) {
|
||||
return fmt.Errorf("%w: %d byte datagram over path budget %d", netErr.ErrDatagramTooLarge, size, tooLarge.MaxDatagramPayloadSize)
|
||||
}
|
||||
return c.remoteCloseErrHandling(err)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
@@ -23,6 +24,12 @@ func (d Dialer) Protocol() string {
|
||||
return Network
|
||||
}
|
||||
|
||||
// DatagramSized marks QUIC as a datagram-sized transport: relay traffic is
|
||||
// carried in QUIC DATAGRAM frames, which must fit a single packet.
|
||||
func (d Dialer) DatagramSized() {
|
||||
// Intentional marker method; presence is the capability signal.
|
||||
}
|
||||
|
||||
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
|
||||
quicURL, err := prepareURL(address)
|
||||
if err != nil {
|
||||
@@ -47,6 +54,7 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
|
||||
MaxIdleTimeout: 4 * time.Minute,
|
||||
EnableDatagrams: true,
|
||||
InitialPacketSize: nbRelay.QUICInitialPacketSize,
|
||||
Tracer: connectionTracer(quicURL),
|
||||
}
|
||||
|
||||
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{Port: 0})
|
||||
@@ -74,6 +82,28 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// connectionTracer returns a QUIC tracer that logs the DPLPMTUD result and the
|
||||
// reason a relay connection closed, so the path MTU settled on and teardown
|
||||
// cause are visible in logs. Lines carry the relay address as a structured
|
||||
// field, matching the rest of the relay client logging.
|
||||
func connectionTracer(addr string) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
relayLog := log.WithField("relay", addr)
|
||||
return func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return &logging.ConnectionTracer{
|
||||
UpdatedMTU: func(mtu logging.ByteCount, done bool) {
|
||||
if done {
|
||||
relayLog.Infof("QUIC path MTU settled at %d", mtu)
|
||||
return
|
||||
}
|
||||
relayLog.Debugf("QUIC path MTU probing at %d", mtu)
|
||||
},
|
||||
ClosedConnection: func(err error) {
|
||||
relayLog.Debugf("QUIC connection closed: %v", err)
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func prepareURL(address string) (string, error) {
|
||||
var host string
|
||||
var defaultPort string
|
||||
|
||||
@@ -32,6 +32,7 @@ type RaceDial struct {
|
||||
serverName string
|
||||
dialerFns []DialeFn
|
||||
connectionTimeout time.Duration
|
||||
sequential bool
|
||||
}
|
||||
|
||||
func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
|
||||
@@ -53,7 +54,21 @@ func (r *RaceDial) WithServerName(serverName string) *RaceDial {
|
||||
return r
|
||||
}
|
||||
|
||||
// WithSequential makes Dial try the dialers in order, falling back to the next
|
||||
// only when one fails to connect, instead of racing them concurrently.
|
||||
//
|
||||
// Mutates the receiver and is not safe for concurrent reconfiguration; a
|
||||
// RaceDial is intended to be constructed per dial and discarded.
|
||||
func (r *RaceDial) WithSequential() *RaceDial {
|
||||
r.sequential = true
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
|
||||
if r.sequential {
|
||||
return r.dialSequential(ctx)
|
||||
}
|
||||
|
||||
connChan := make(chan dialResult, len(r.dialerFns))
|
||||
winnerConn := make(chan net.Conn, 1)
|
||||
abortCtx, abort := context.WithCancel(ctx)
|
||||
@@ -72,6 +87,30 @@ func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// dialSequential tries each dialer in order, returning the first connection and
|
||||
// falling back to the next on failure.
|
||||
func (r *RaceDial) dialSequential(ctx context.Context) (net.Conn, error) {
|
||||
for _, dfn := range r.dialerFns {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attemptCtx, cancel := context.WithTimeout(ctx, r.connectionTimeout)
|
||||
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
|
||||
conn, err := dfn.Dial(attemptCtx, r.serverURL, r.serverName)
|
||||
cancel()
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, err
|
||||
}
|
||||
r.log.Errorf("failed to dial via %s: %s", dfn.Protocol(), err)
|
||||
continue
|
||||
}
|
||||
r.log.Infof("successfully dialed via: %s", dfn.Protocol())
|
||||
return conn, nil
|
||||
}
|
||||
return nil, errors.New("failed to dial to Relay server on any protocol")
|
||||
}
|
||||
|
||||
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
|
||||
ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -250,3 +250,66 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialSequentialFallback(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
var firstDialed, secondDialed bool
|
||||
preferred := &MockDialer{
|
||||
protocolStr: "quic",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
firstDialed = true
|
||||
return nil, errors.New("quic unreachable")
|
||||
},
|
||||
}
|
||||
fallbackConn := &MockConn{remoteAddr: &MockAddr{network: "ws"}}
|
||||
fallback := &MockDialer{
|
||||
protocolStr: "ws",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
secondDialed = true
|
||||
return fallbackConn, nil
|
||||
},
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, preferred, fallback).WithSequential()
|
||||
conn, err := rd.Dial(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("expected fallback to succeed, got %v", err)
|
||||
}
|
||||
if conn != fallbackConn {
|
||||
t.Errorf("expected fallback connection, got %v", conn)
|
||||
}
|
||||
if !firstDialed || !secondDialed {
|
||||
t.Errorf("expected both dialers attempted in order, first=%v second=%v", firstDialed, secondDialed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialSequentialPreferredWins(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
preferredConn := &MockConn{remoteAddr: &MockAddr{network: "quic"}}
|
||||
preferred := &MockDialer{
|
||||
protocolStr: "quic",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
return preferredConn, nil
|
||||
},
|
||||
}
|
||||
fallback := &MockDialer{
|
||||
protocolStr: "ws",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
t.Errorf("fallback dialer must not be tried when preferred succeeds")
|
||||
return nil, errors.New("should not happen")
|
||||
},
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, preferred, fallback).WithSequential()
|
||||
conn, err := rd.Dial(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("expected preferred to succeed, got %v", err)
|
||||
}
|
||||
if conn != preferredConn {
|
||||
t.Errorf("expected preferred connection, got %v", conn)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,11 +9,42 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
|
||||
)
|
||||
|
||||
// getDialers returns the list of dialers to use for connecting to the relay server.
|
||||
func (c *Client) getDialers() []dialer.DialeFn {
|
||||
if c.mtu > 0 && c.mtu > iface.DefaultMTU {
|
||||
c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU)
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
// getDialers returns the ordered dialers for connecting to the relay server. It
|
||||
// applies the datagram fallback generically: if this server recently rejected a
|
||||
// datagram-sized transport, those dialers are dropped, leaving the rest.
|
||||
func (c *Client) getDialers(mode TransportMode) []dialer.DialeFn {
|
||||
dialers := c.baseDialers(mode)
|
||||
|
||||
if c.transportFallback != nil && c.transportFallback.avoidDatagramSized(c.connectionURL) {
|
||||
if filtered := nonDatagramSized(dialers); len(filtered) > 0 {
|
||||
c.log.Infof("relay recently rejected a datagram-sized transport, avoiding it")
|
||||
return filtered
|
||||
}
|
||||
}
|
||||
return []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
|
||||
return dialers
|
||||
}
|
||||
|
||||
// baseDialers returns the ordered dialers for the mode, before any datagram
|
||||
// fallback filtering. For racing modes (auto) the order is irrelevant; for
|
||||
// prefer modes the first entry is tried before falling back to the second.
|
||||
func (c *Client) baseDialers(mode TransportMode) []dialer.DialeFn {
|
||||
switch mode {
|
||||
case TransportModeWS:
|
||||
c.log.Infof("%s=ws, using WebSocket transport", EnvRelayTransport)
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
case TransportModeQUIC:
|
||||
c.log.Infof("%s=quic, using QUIC transport", EnvRelayTransport)
|
||||
return []dialer.DialeFn{quic.Dialer{}}
|
||||
}
|
||||
|
||||
all := []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
|
||||
if mode == TransportModePreferWS {
|
||||
all = []dialer.DialeFn{ws.Dialer{}, quic.Dialer{}}
|
||||
}
|
||||
|
||||
if c.mtu > 0 && c.mtu > iface.DefaultMTU {
|
||||
c.log.Infof("MTU %d exceeds default (%d), avoiding datagram-sized transports", c.mtu, iface.DefaultMTU)
|
||||
return nonDatagramSized(all)
|
||||
}
|
||||
return all
|
||||
}
|
||||
|
||||
101
shared/relay/client/dialers_generic_test.go
Normal file
101
shared/relay/client/dialers_generic_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
//go:build !js
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer"
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/quic"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
|
||||
)
|
||||
|
||||
// TestDatagramSizedCapability locks the capability the generic fallback relies
|
||||
// on: QUIC is datagram-sized, WebSocket is not.
|
||||
func TestDatagramSizedCapability(t *testing.T) {
|
||||
assert.True(t, dialer.IsDatagramSized(quic.Dialer{}), "QUIC must advertise datagram-sized")
|
||||
assert.False(t, dialer.IsDatagramSized(ws.Dialer{}), "WebSocket must not advertise datagram-sized")
|
||||
}
|
||||
|
||||
func protocols(dialers []dialer.DialeFn) []string {
|
||||
out := make([]string, len(dialers))
|
||||
for i, d := range dialers {
|
||||
out[i] = d.Protocol()
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestGetDialers(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mode string
|
||||
mtu uint16
|
||||
preferWS bool
|
||||
want []string
|
||||
}{
|
||||
{name: "auto races quic and ws", mode: "auto", mtu: iface.DefaultMTU, want: []string{"quic", "WS"}},
|
||||
{name: "ws pinned", mode: "ws", mtu: iface.DefaultMTU, want: []string{"WS"}},
|
||||
{name: "quic pinned", mode: "quic", mtu: iface.DefaultMTU, want: []string{"quic"}},
|
||||
{name: "prefer-quic orders quic first", mode: "prefer-quic", mtu: iface.DefaultMTU, want: []string{"quic", "WS"}},
|
||||
{name: "prefer-ws orders ws first", mode: "prefer-ws", mtu: iface.DefaultMTU, want: []string{"WS", "quic"}},
|
||||
{name: "mtu above default forces ws", mode: "auto", mtu: iface.DefaultMTU + 100, want: []string{"WS"}},
|
||||
{name: "sticky fallback forces ws in auto", mode: "auto", mtu: iface.DefaultMTU, preferWS: true, want: []string{"WS"}},
|
||||
{name: "sticky fallback forces ws in prefer-quic", mode: "prefer-quic", mtu: iface.DefaultMTU, preferWS: true, want: []string{"WS"}},
|
||||
{name: "quic pin overrides sticky fallback", mode: "quic", mtu: iface.DefaultMTU, preferWS: true, want: []string{"quic"}},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(EnvRelayTransport, tc.mode)
|
||||
if tc.mode == "" {
|
||||
os.Unsetenv(EnvRelayTransport)
|
||||
}
|
||||
|
||||
tf := newTransportFallback()
|
||||
if tc.preferWS {
|
||||
tf.recordFailure(url)
|
||||
}
|
||||
|
||||
c := &Client{
|
||||
log: log.WithField("test", t.Name()),
|
||||
connectionURL: url,
|
||||
mtu: tc.mtu,
|
||||
transportFallback: tf,
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.want, protocols(c.getDialers(transportModeFromEnv())))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStickyFallbackAfterDatagramTooLarge verifies the full chain: an oversized
|
||||
// datagram records a fallback that makes the next dial pick WebSocket, the way a
|
||||
// reconnect would after the connection is closed.
|
||||
func TestStickyFallbackAfterDatagramTooLarge(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
t.Setenv(EnvRelayTransport, string(TransportModeAuto))
|
||||
|
||||
c := &Client{
|
||||
log: log.WithField("test", t.Name()),
|
||||
connectionURL: url,
|
||||
mtu: iface.DefaultMTU,
|
||||
transportFallback: newTransportFallback(),
|
||||
}
|
||||
|
||||
// First dial races both transports.
|
||||
assert.Equal(t, []string{"quic", "WS"}, protocols(c.getDialers(transportModeFromEnv())))
|
||||
|
||||
// An oversized datagram records the fallback for this server.
|
||||
c.onDatagramTooLarge(&closeTrackingConn{}, netErr.ErrDatagramTooLarge)
|
||||
|
||||
// The reconnect now sticks to WebSocket.
|
||||
assert.Equal(t, []string{"WS"}, protocols(c.getDialers(transportModeFromEnv())))
|
||||
}
|
||||
@@ -7,7 +7,11 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
|
||||
)
|
||||
|
||||
func (c *Client) getDialers() []dialer.DialeFn {
|
||||
func (c *Client) getDialers(_ TransportMode) []dialer.DialeFn {
|
||||
// JS/WASM build only uses WebSocket transport
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
}
|
||||
|
||||
func (c *Client) baseDialers(_ TransportMode) []dialer.DialeFn {
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
}
|
||||
|
||||
@@ -79,23 +79,30 @@ type Manager struct {
|
||||
|
||||
cleanupInterval time.Duration
|
||||
keepUnusedServerTime time.Duration
|
||||
|
||||
// transportFallback is shared across home and foreign relay clients so a
|
||||
// datagram-too-large failure makes that server avoid datagram-sized transports across reconnects.
|
||||
transportFallback *transportFallback
|
||||
}
|
||||
|
||||
// NewManager creates a new manager instance.
|
||||
// The serverURL address can be empty. In this case, the manager will not serve.
|
||||
func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16, opts ...ManagerOption) *Manager {
|
||||
tokenStore := &relayAuth.TokenStore{}
|
||||
tf := newTransportFallback()
|
||||
|
||||
m := &Manager{
|
||||
ctx: ctx,
|
||||
peerID: peerID,
|
||||
tokenStore: tokenStore,
|
||||
mtu: mtu,
|
||||
ctx: ctx,
|
||||
peerID: peerID,
|
||||
tokenStore: tokenStore,
|
||||
mtu: mtu,
|
||||
transportFallback: tf,
|
||||
serverPicker: &ServerPicker{
|
||||
TokenStore: tokenStore,
|
||||
PeerID: peerID,
|
||||
MTU: mtu,
|
||||
ConnectionTimeout: defaultConnectionTimeout,
|
||||
TransportFallback: tf,
|
||||
},
|
||||
relayClients: make(map[string]*RelayTrack),
|
||||
onDisconnectedListeners: make(map[string]*list.List),
|
||||
@@ -287,6 +294,7 @@ func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string
|
||||
m.relayClientsMutex.Unlock()
|
||||
|
||||
relayClient := NewClientWithServerIP(serverAddress, serverIP, m.tokenStore, m.peerID, m.mtu)
|
||||
relayClient.SetTransportFallback(m.transportFallback)
|
||||
err := relayClient.Connect(m.ctx)
|
||||
if err != nil {
|
||||
rt.err = err
|
||||
|
||||
@@ -29,6 +29,7 @@ type ServerPicker struct {
|
||||
PeerID string
|
||||
MTU uint16
|
||||
ConnectionTimeout time.Duration
|
||||
TransportFallback *transportFallback
|
||||
}
|
||||
|
||||
func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
|
||||
@@ -70,6 +71,7 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
|
||||
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
|
||||
log.Infof("try to connecting to relay server: %s", url)
|
||||
relayClient := NewClient(url, sp.TokenStore, sp.PeerID, sp.MTU)
|
||||
relayClient.SetTransportFallback(sp.TransportFallback)
|
||||
err := relayClient.Connect(ctx)
|
||||
resultChan <- connResult{
|
||||
RelayClient: relayClient,
|
||||
|
||||
129
shared/relay/client/transport.go
Normal file
129
shared/relay/client/transport.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer"
|
||||
)
|
||||
|
||||
// EnvRelayTransport pins the relay transport. Valid values: "auto" (default,
|
||||
// race QUIC and WebSocket), "quic" (QUIC only), "ws" (WebSocket only),
|
||||
// "prefer-quic" / "prefer-ws" (try the preferred transport first, fall back to
|
||||
// the other only if it fails to connect; no race). The prefer modes trade a
|
||||
// slower connect when the preferred transport is blackholed for deterministic
|
||||
// transport selection.
|
||||
const EnvRelayTransport = "NB_RELAY_TRANSPORT"
|
||||
|
||||
const (
|
||||
// transportFallbackBase is the initial window a relay server avoids
|
||||
// datagram-sized transports after a datagram is rejected as too large.
|
||||
transportFallbackBase = 10 * time.Minute
|
||||
// transportFallbackMax caps the pinned window when failures repeat.
|
||||
transportFallbackMax = 60 * time.Minute
|
||||
)
|
||||
|
||||
// TransportMode selects which relay dialers are used.
|
||||
type TransportMode string
|
||||
|
||||
const (
|
||||
TransportModeAuto TransportMode = "auto"
|
||||
TransportModeQUIC TransportMode = "quic"
|
||||
TransportModeWS TransportMode = "ws"
|
||||
TransportModePreferQUIC TransportMode = "prefer-quic"
|
||||
TransportModePreferWS TransportMode = "prefer-ws"
|
||||
)
|
||||
|
||||
// transportModeFromEnv reads EnvRelayTransport, defaulting to auto for an empty
|
||||
// or unrecognized value.
|
||||
func transportModeFromEnv() TransportMode {
|
||||
switch TransportMode(strings.ToLower(strings.TrimSpace(os.Getenv(EnvRelayTransport)))) {
|
||||
case "", TransportModeAuto:
|
||||
return TransportModeAuto
|
||||
case TransportModeQUIC:
|
||||
return TransportModeQUIC
|
||||
case TransportModeWS:
|
||||
return TransportModeWS
|
||||
case TransportModePreferQUIC:
|
||||
return TransportModePreferQUIC
|
||||
case TransportModePreferWS:
|
||||
return TransportModePreferWS
|
||||
default:
|
||||
log.Warnf("invalid %s value %q, using %q", EnvRelayTransport, os.Getenv(EnvRelayTransport), TransportModeAuto)
|
||||
return TransportModeAuto
|
||||
}
|
||||
}
|
||||
|
||||
// sequential reports whether the mode tries dialers in order with fallback
|
||||
// instead of racing them concurrently.
|
||||
func (m TransportMode) sequential() bool {
|
||||
return m == TransportModePreferQUIC || m == TransportModePreferWS
|
||||
}
|
||||
|
||||
// transportFallback tracks relay servers that have rejected a datagram-sized
|
||||
// transport (a write too large for the path) and should temporarily avoid such
|
||||
// transports. It is shared across the relay manager so the preference survives
|
||||
// client recreation (foreign relay clients are evicted and rebuilt on
|
||||
// disconnect). Entries are keyed by server URL and expire after a window that
|
||||
// grows on repeated failures.
|
||||
type transportFallback struct {
|
||||
mu sync.Mutex
|
||||
entries map[string]*fallbackEntry
|
||||
}
|
||||
|
||||
type fallbackEntry struct {
|
||||
until time.Time
|
||||
duration time.Duration
|
||||
}
|
||||
|
||||
func newTransportFallback() *transportFallback {
|
||||
return &transportFallback{entries: make(map[string]*fallbackEntry)}
|
||||
}
|
||||
|
||||
// avoidDatagramSized reports whether serverURL is currently within a window
|
||||
// where datagram-sized transports should be avoided.
|
||||
func (f *transportFallback) avoidDatagramSized(serverURL string) bool {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
e := f.entries[serverURL]
|
||||
return e != nil && time.Now().Before(e.until)
|
||||
}
|
||||
|
||||
// recordFailure makes serverURL avoid datagram-sized transports for a window:
|
||||
// transportFallbackBase on the first failure, doubling up to transportFallbackMax
|
||||
// when a datagram transport fails again after a previous window expired. It
|
||||
// returns the active window duration.
|
||||
func (f *transportFallback) recordFailure(serverURL string) time.Duration {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
e := f.entries[serverURL]
|
||||
switch {
|
||||
case e == nil:
|
||||
e = &fallbackEntry{duration: transportFallbackBase}
|
||||
f.entries[serverURL] = e
|
||||
case now.Before(e.until):
|
||||
return time.Until(e.until)
|
||||
default:
|
||||
e.duration = min(e.duration*2, transportFallbackMax)
|
||||
}
|
||||
e.until = now.Add(e.duration)
|
||||
return e.duration
|
||||
}
|
||||
|
||||
// nonDatagramSized returns the dialers from in that are not datagram-sized,
|
||||
// preserving order.
|
||||
func nonDatagramSized(in []dialer.DialeFn) []dialer.DialeFn {
|
||||
out := make([]dialer.DialeFn, 0, len(in))
|
||||
for _, d := range in {
|
||||
if !dialer.IsDatagramSized(d) {
|
||||
out = append(out, d)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
140
shared/relay/client/transport_test.go
Normal file
140
shared/relay/client/transport_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
)
|
||||
|
||||
// closeTrackingConn records whether Close was called; only Close is exercised.
|
||||
type closeTrackingConn struct {
|
||||
net.Conn
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (c *closeTrackingConn) Close() error {
|
||||
c.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestTransportModeFromEnv(t *testing.T) {
|
||||
tests := []struct {
|
||||
value string
|
||||
want TransportMode
|
||||
}{
|
||||
{"", TransportModeAuto},
|
||||
{"auto", TransportModeAuto},
|
||||
{"quic", TransportModeQUIC},
|
||||
{"QUIC", TransportModeQUIC},
|
||||
{"ws", TransportModeWS},
|
||||
{" Ws ", TransportModeWS},
|
||||
{"prefer-quic", TransportModePreferQUIC},
|
||||
{"prefer-ws", TransportModePreferWS},
|
||||
{"garbage", TransportModeAuto},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.value, func(t *testing.T) {
|
||||
t.Setenv(EnvRelayTransport, tc.value)
|
||||
if tc.value == "" {
|
||||
os.Unsetenv(EnvRelayTransport)
|
||||
}
|
||||
assert.Equal(t, tc.want, transportModeFromEnv())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportFallbackRecordAndExpiry(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
f := newTransportFallback()
|
||||
|
||||
assert.False(t, f.avoidDatagramSized(url), "no fallback recorded yet")
|
||||
|
||||
d := f.recordFailure(url)
|
||||
assert.Equal(t, transportFallbackBase, d, "first failure pins for the base window")
|
||||
assert.True(t, f.avoidDatagramSized(url), "datagram-sized transport avoided within the window")
|
||||
|
||||
// A second failure while still inside the window must not grow the window.
|
||||
d = f.recordFailure(url)
|
||||
assert.LessOrEqual(t, d, transportFallbackBase, "still within the active window")
|
||||
require.NotNil(t, f.entries[url])
|
||||
assert.Equal(t, transportFallbackBase, f.entries[url].duration, "duration unchanged inside window")
|
||||
|
||||
// Expire the window: datagram-sized transport allowed again.
|
||||
f.entries[url].until = time.Now().Add(-time.Second)
|
||||
assert.False(t, f.avoidDatagramSized(url), "window expired, datagram-sized transport allowed")
|
||||
}
|
||||
|
||||
func TestTransportFallbackGrowsOnRepeat(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
f := newTransportFallback()
|
||||
|
||||
want := transportFallbackBase
|
||||
for i := range 6 {
|
||||
d := f.recordFailure(url)
|
||||
assert.Equal(t, want, d, "window after %d expiries", i)
|
||||
|
||||
// expire the window so the next failure is treated as a repeat
|
||||
f.entries[url].until = time.Now().Add(-time.Second)
|
||||
|
||||
want = min(want*2, transportFallbackMax)
|
||||
}
|
||||
|
||||
assert.Equal(t, transportFallbackMax, f.entries[url].duration, "window caps at the max")
|
||||
}
|
||||
|
||||
func TestOnDatagramTooLargeAuto(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
t.Setenv(EnvRelayTransport, string(TransportModeAuto))
|
||||
|
||||
tf := newTransportFallback()
|
||||
c := &Client{
|
||||
log: log.WithField("test", t.Name()),
|
||||
connectionURL: url,
|
||||
transportFallback: tf,
|
||||
}
|
||||
conn := &closeTrackingConn{}
|
||||
|
||||
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
|
||||
|
||||
assert.True(t, conn.closed, "connection closed to force reconnect")
|
||||
assert.True(t, tf.avoidDatagramSized(url), "fallback recorded for the server")
|
||||
|
||||
// A second oversized datagram on the same connection must not re-close.
|
||||
conn.closed = false
|
||||
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
|
||||
assert.False(t, conn.closed, "single fallback per connection")
|
||||
}
|
||||
|
||||
func TestOnDatagramTooLargeQUICPinned(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
t.Setenv(EnvRelayTransport, string(TransportModeQUIC))
|
||||
|
||||
tf := newTransportFallback()
|
||||
c := &Client{
|
||||
log: log.WithField("test", t.Name()),
|
||||
connectionURL: url,
|
||||
transportFallback: tf,
|
||||
}
|
||||
conn := &closeTrackingConn{}
|
||||
|
||||
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
|
||||
|
||||
assert.False(t, conn.closed, "QUIC pin keeps the connection, no fallback redial")
|
||||
assert.False(t, tf.avoidDatagramSized(url), "QUIC pin records no fallback")
|
||||
}
|
||||
|
||||
func TestTransportFallbackPerServer(t *testing.T) {
|
||||
f := newTransportFallback()
|
||||
f.recordFailure("rels://a.example:443")
|
||||
|
||||
assert.True(t, f.avoidDatagramSized("rels://a.example:443"))
|
||||
assert.False(t, f.avoidDatagramSized("rels://b.example:443"), "fallback is scoped to one server")
|
||||
}
|
||||
Reference in New Issue
Block a user