Compare commits

..

7 Commits

Author SHA1 Message Date
bcmmbaga
6240dcd96a add logs 2026-05-08 13:57:41 +03:00
Viktor Liu
f532976e05 [client] Add public key to debug bundle config.txt (#6092) 2026-05-06 13:42:47 +02:00
Viktor Liu
71a400f90f [client] Include MTU and SSH auth/JWT cache config in debug bundle (#6071) 2026-05-06 13:23:43 +02:00
Pascal Fischer
bfeb9b19ec [management] remove permissions from geolocations api (#6091) 2026-05-06 13:07:01 +02:00
Pascal Fischer
b19b7464ea [management] fix flaky invite token test (#6077) 2026-05-05 18:48:51 +02:00
Pascal Fischer
cfb1b3fe31 [proxy] consolidate mapping update (#6072) 2026-05-05 18:40:42 +02:00
Bethuel Mmbaga
3c28d29725 [management] Map Entra oid claim as Dex user ID (#6067) 2026-05-05 18:12:18 +03:00
31 changed files with 1069 additions and 1762 deletions

View File

@@ -92,9 +92,6 @@ linters:
- linters: - linters:
- unused - unused
path: client/firewall/iptables/rule\.go path: client/firewall/iptables/rule\.go
- linters:
- unused
path: client/internal/dns/dnsfw/(types|syscall|zsyscall)_windows.*\.go
- linters: - linters:
- gosec - gosec
- mirror - mirror

View File

@@ -21,6 +21,7 @@ import (
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
"github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/anonymize"
@@ -583,6 +584,9 @@ func isSensitiveEnvVar(key string) bool {
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) { func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
configContent.WriteString("NetBird Client Configuration:\n\n") configContent.WriteString("NetBird Client Configuration:\n\n")
if key, err := wgtypes.ParseKey(g.internalConfig.PrivateKey); err == nil {
configContent.WriteString(fmt.Sprintf("PublicKey: %s\n", key.PublicKey().String()))
}
configContent.WriteString(fmt.Sprintf("WgIface: %s\n", g.internalConfig.WgIface)) configContent.WriteString(fmt.Sprintf("WgIface: %s\n", g.internalConfig.WgIface))
configContent.WriteString(fmt.Sprintf("WgPort: %d\n", g.internalConfig.WgPort)) configContent.WriteString(fmt.Sprintf("WgPort: %d\n", g.internalConfig.WgPort))
if g.internalConfig.NetworkMonitor != nil { if g.internalConfig.NetworkMonitor != nil {
@@ -607,6 +611,12 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
if g.internalConfig.EnableSSHRemotePortForwarding != nil { if g.internalConfig.EnableSSHRemotePortForwarding != nil {
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding)) configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
} }
if g.internalConfig.DisableSSHAuth != nil {
configContent.WriteString(fmt.Sprintf("DisableSSHAuth: %v\n", *g.internalConfig.DisableSSHAuth))
}
if g.internalConfig.SSHJWTCacheTTL != nil {
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
}
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes)) configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes)) configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
@@ -633,6 +643,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
} }
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled)) configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU))
} }
func (g *BundleGenerator) addProf() (err error) { func (g *BundleGenerator) addProf() (err error) {

View File

@@ -5,16 +5,21 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"net" "net"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"strings" "strings"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/configs" "github.com/netbirdio/netbird/client/configs"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
) )
@@ -471,8 +476,8 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
anonymize: false, anonymize: false,
input: map[string]any{ input: map[string]any{
jsonKeyServiceEnv: map[string]any{ jsonKeyServiceEnv: map[string]any{
"HOME": "/root", "HOME": "/root",
"PATH": "/usr/bin", "PATH": "/usr/bin",
"NB_LOG_LEVEL": "debug", "NB_LOG_LEVEL": "debug",
}, },
}, },
@@ -489,9 +494,9 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
anonymize: false, anonymize: false,
input: map[string]any{ input: map[string]any{
jsonKeyServiceEnv: map[string]any{ jsonKeyServiceEnv: map[string]any{
"NB_SETUP_KEY": "abc123", "NB_SETUP_KEY": "abc123",
"NB_API_TOKEN": "tok_xyz", "NB_API_TOKEN": "tok_xyz",
"NB_LOG_LEVEL": "info", "NB_LOG_LEVEL": "info",
}, },
}, },
check: func(t *testing.T, params map[string]any) { check: func(t *testing.T, params map[string]any) {
@@ -766,3 +771,127 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
assert.Contains(t, anonNftables, "chain input {") assert.Contains(t, anonNftables, "chain input {")
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;") assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
} }
// TestAddConfig_AllFieldsCovered uses reflection to ensure every field in
// profilemanager.Config is either rendered in the debug bundle or explicitly
// excluded. When a new field is added to Config, this test fails until the
// developer either dumps it in addConfig/addCommonConfigFields or adds it to
// the excluded set with a justification.
func TestAddConfig_AllFieldsCovered(t *testing.T) {
excluded := map[string]string{
"PrivateKey": "sensitive: WireGuard private key",
"PreSharedKey": "sensitive: WireGuard pre-shared key",
"SSHKey": "sensitive: SSH private key",
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
}
mURL, _ := url.Parse("https://api.example.com:443")
aURL, _ := url.Parse("https://admin.example.com:443")
bTrue := true
iVal := 42
cfg := &profilemanager.Config{
PrivateKey: "priv",
PreSharedKey: "psk",
ManagementURL: mURL,
AdminURL: aURL,
WgIface: "wt0",
WgPort: 51820,
NetworkMonitor: &bTrue,
IFaceBlackList: []string{"eth0"},
DisableIPv6Discovery: true,
RosenpassEnabled: true,
RosenpassPermissive: true,
ServerSSHAllowed: &bTrue,
EnableSSHRoot: &bTrue,
EnableSSHSFTP: &bTrue,
EnableSSHLocalPortForwarding: &bTrue,
EnableSSHRemotePortForwarding: &bTrue,
DisableSSHAuth: &bTrue,
SSHJWTCacheTTL: &iVal,
DisableClientRoutes: true,
DisableServerRoutes: true,
DisableDNS: true,
DisableFirewall: true,
BlockLANAccess: true,
BlockInbound: true,
DisableNotifications: &bTrue,
DNSLabels: domain.List{},
SSHKey: "sshkey",
NATExternalIPs: []string{"1.2.3.4"},
CustomDNSAddress: "1.1.1.1:53",
DisableAutoConnect: true,
DNSRouteInterval: 5 * time.Second,
ClientCertPath: "/tmp/cert",
ClientCertKeyPath: "/tmp/key",
LazyConnectionEnabled: true,
MTU: 1280,
}
for _, anonymize := range []bool{false, true} {
t.Run("anonymize="+map[bool]string{true: "true", false: "false"}[anonymize], func(t *testing.T) {
g := &BundleGenerator{
anonymizer: newAnonymizerForTest(),
internalConfig: cfg,
anonymize: anonymize,
}
var sb strings.Builder
g.addCommonConfigFields(&sb)
rendered := sb.String() + renderAddConfigSpecific(g)
val := reflect.ValueOf(cfg).Elem()
typ := val.Type()
var missing []string
for i := 0; i < typ.NumField(); i++ {
name := typ.Field(i).Name
if _, ok := excluded[name]; ok {
continue
}
if !strings.Contains(rendered, name+":") {
missing = append(missing, name)
}
}
if len(missing) > 0 {
t.Fatalf("Config field(s) not present in debug bundle output: %v\n"+
"Either render the field in addCommonConfigFields/addConfig, "+
"or add it to the excluded map with a justification.", missing)
}
})
}
}
// renderAddConfigSpecific renders the fields handled by the anonymize/non-anonymize
// branches in addConfig (ManagementURL, AdminURL, NATExternalIPs, CustomDNSAddress).
// addCommonConfigFields covers the rest. Keeping this in the test mirrors the
// production shape without needing to write an actual zip.
func renderAddConfigSpecific(g *BundleGenerator) string {
var sb strings.Builder
if g.anonymize {
if g.internalConfig.ManagementURL != nil {
sb.WriteString("ManagementURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.ManagementURL.String()) + "\n")
}
if g.internalConfig.AdminURL != nil {
sb.WriteString("AdminURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.AdminURL.String()) + "\n")
}
sb.WriteString("NATExternalIPs: x\n")
if g.internalConfig.CustomDNSAddress != "" {
sb.WriteString("CustomDNSAddress: " + g.anonymizer.AnonymizeString(g.internalConfig.CustomDNSAddress) + "\n")
}
} else {
if g.internalConfig.ManagementURL != nil {
sb.WriteString("ManagementURL: " + g.internalConfig.ManagementURL.String() + "\n")
}
if g.internalConfig.AdminURL != nil {
sb.WriteString("AdminURL: " + g.internalConfig.AdminURL.String() + "\n")
}
sb.WriteString("NATExternalIPs: x\n")
if g.internalConfig.CustomDNSAddress != "" {
sb.WriteString("CustomDNSAddress: " + g.internalConfig.CustomDNSAddress + "\n")
}
}
return sb.String()
}
func newAnonymizerForTest() *anonymize.Anonymizer {
return anonymize.NewAnonymizer(anonymize.DefaultAddresses())
}

View File

@@ -1,63 +0,0 @@
package dnsfw
import (
"os"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
)
const (
// EnvDisable disables the DNS firewall entirely when set to a truthy value.
EnvDisable = "NB_DISABLE_DNS_FIREWALL"
// EnvPorts overrides the comma-separated list of remote ports to block.
// Empty disables the firewall.
EnvPorts = "NB_DNS_FIREWALL_PORTS"
// EnvStrict enables strict mode: permit DNS only to the virtual DNS IP
// and the netbird daemon. Default mode also permits anything on the
// netbird tunnel interface, which is safer if NRPT is silently ignored
// by Windows but lets apps reach custom DNS servers via the tunnel.
EnvStrict = "NB_DNS_FIREWALL_STRICT"
)
// defaultBlockedPorts are the well-known DNS ports we block for non-netbird
// processes: 53 (plain DNS) and 853 (DNS-over-TLS).
var defaultBlockedPorts = []uint16{53, 853}
// blockedPorts returns the effective port list, honoring env overrides.
// A nil return means the firewall should not be installed.
func blockedPorts() []uint16 {
if disabled, _ := strconv.ParseBool(os.Getenv(EnvDisable)); disabled {
log.Infof("dns firewall disabled via %s", EnvDisable)
return nil
}
override, ok := os.LookupEnv(EnvPorts)
if !ok {
return defaultBlockedPorts
}
var ports []uint16
for _, raw := range strings.Split(override, ",") {
raw = strings.TrimSpace(raw)
if raw == "" {
continue
}
port, err := strconv.ParseUint(raw, 10, 16)
if err != nil {
log.Warnf("dns firewall: ignoring invalid port %q in %s: %v", raw, EnvPorts, err)
continue
}
if port == 0 {
log.Warnf("dns firewall: ignoring port 0 in %s", EnvPorts)
continue
}
ports = append(ports, uint16(port))
}
if len(ports) == 0 {
log.Infof("dns firewall disabled: %s yielded no valid ports", EnvPorts)
return nil
}
return ports
}

View File

@@ -1,39 +0,0 @@
package dnsfw
import (
"reflect"
"testing"
)
func TestBlockedPorts(t *testing.T) {
tests := []struct {
name string
disable string
ports string
setPorts bool
want []uint16
}{
{name: "default", want: defaultBlockedPorts},
{name: "disabled", disable: "true", want: nil},
{name: "disabled false keeps default", disable: "false", want: defaultBlockedPorts},
{name: "override single port", ports: "53", setPorts: true, want: []uint16{53}},
{name: "override multi", ports: "53, 853 ,5353", setPorts: true, want: []uint16{53, 853, 5353}},
{name: "override empty disables", ports: "", setPorts: true, want: nil},
{name: "override invalid skipped", ports: "53,not-a-port,853", setPorts: true, want: []uint16{53, 853}},
{name: "override zero skipped", ports: "53,0,853", setPorts: true, want: []uint16{53, 853}},
{name: "override only invalid disables", ports: "abc", setPorts: true, want: nil},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(EnvDisable, tc.disable)
if tc.setPorts {
t.Setenv(EnvPorts, tc.ports)
}
got := blockedPorts()
if !reflect.DeepEqual(got, tc.want) {
t.Fatalf("blockedPorts() = %v, want %v", got, tc.want)
}
})
}
}

View File

@@ -1,16 +0,0 @@
// Package dnsfw blocks DNS traffic from non-netbird processes when netbird is
// managing the host's DNS, so that resolvers running on apps or libraries
// outside netbird cannot bypass the configured DNS path.
//
// Implementation is Windows-only (uses WFP). On other platforms New returns
// a no-op manager.
package dnsfw
import "net/netip"
// Manager controls the per-tunnel DNS firewall. Both methods must be safe
// to call multiple times.
type Manager interface {
Enable(ifaceGUID string, virtualDNSIP netip.Addr) error
Disable() error
}

View File

@@ -1,15 +0,0 @@
//go:build !windows
package dnsfw
import "net/netip"
type noopManager struct{}
func (noopManager) Enable(string, netip.Addr) error { return nil }
func (noopManager) Disable() error { return nil }
// New returns a no-op manager on non-Windows platforms.
func New() Manager {
return noopManager{}
}

View File

@@ -1,144 +0,0 @@
//go:build windows
package dnsfw
import (
"fmt"
"net/netip"
"os"
"strconv"
"sync"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
var (
modIphlpapi = windows.NewLazyDLL("iphlpapi.dll")
procConvertInterfaceGuidToLuid = modIphlpapi.NewProc("ConvertInterfaceGuidToLuid")
)
type windowsManager struct {
mu sync.Mutex
// session is the WFP engine handle. Zero when disabled.
session uintptr
}
// Enable installs the dns firewall. Strict mode propagates failures;
// non-strict mode logs and returns nil so partial protection is preserved.
func (m *windowsManager) Enable(ifaceGUID string, virtualDNSIP netip.Addr) error {
m.mu.Lock()
defer m.mu.Unlock()
ports := blockedPorts()
if len(ports) == 0 {
return nil
}
if m.session != 0 {
if err := m.disableLocked(); err != nil {
return fmt.Errorf("reset existing dns firewall session: %w", err)
}
}
strict := strictMode()
luid, err := luidFromGUID(ifaceGUID)
if err != nil {
return m.failOrLog(strict, fmt.Errorf("resolve tun luid from guid %s: %w", ifaceGUID, err))
}
exe, err := os.Executable()
if err != nil {
return m.failOrLog(strict, fmt.Errorf("resolve daemon executable path: %w", err))
}
cfg := installConfig{
tunLUID: luid,
daemonExe: exe,
blockedPorts: ports,
strict: strict,
virtualDNSIP: virtualDNSIP,
}
// session==0 signals a hard failure; non-zero with non-nil err is a partial install.
session, installErr := installFilters(cfg)
if session == 0 {
return m.failOrLog(strict, fmt.Errorf("install dns firewall filters: %w", installErr))
}
if installErr != nil && strict {
_ = closeSession(session)
return fmt.Errorf("strict dns firewall: partial install: %w", installErr)
}
m.session = session
log.Infof("dns firewall installed: iface=%s daemon=%s ports=%v strict=%v virtual_dns=%s",
ifaceGUID, exe, ports, strict, virtualDNSIP)
if installErr != nil {
log.Warnf("dns firewall partially installed (some filters failed): %v", installErr)
}
return nil
}
func (m *windowsManager) Disable() error {
m.mu.Lock()
defer m.mu.Unlock()
return m.disableLocked()
}
func (m *windowsManager) disableLocked() error {
if m.session == 0 {
return nil
}
session := m.session
m.session = 0
if err := closeSession(session); err != nil {
return fmt.Errorf("close wfp session: %w", err)
}
log.Info("dns firewall removed")
return nil
}
// failOrLog returns err unchanged in strict mode. In non-strict mode the
// error is logged and nil is returned.
func (m *windowsManager) failOrLog(strict bool, err error) error {
if strict {
return err
}
log.Errorf("dns firewall: %v", err)
return nil
}
// New returns a Windows DNS firewall manager backed by WFP.
func New() Manager {
return &windowsManager{}
}
// strictMode reports whether strict mode is enabled via env.
func strictMode() bool {
v, _ := strconv.ParseBool(os.Getenv(EnvStrict))
return v
}
// luidFromGUID converts a Windows interface GUID string to its LUID.
func luidFromGUID(ifaceGUID string) (luid uint64, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in luidFromGUID: %v", r)
}
}()
guid, err := windows.GUIDFromString(ifaceGUID)
if err != nil {
return 0, fmt.Errorf("parse guid: %w", err)
}
rc, _, _ := procConvertInterfaceGuidToLuid.Call(
uintptr(unsafe.Pointer(&guid)),
uintptr(unsafe.Pointer(&luid)),
)
if rc != 0 {
return 0, fmt.Errorf("ConvertInterfaceGuidToLuid returned %d", rc)
}
return luid, nil
}

View File

@@ -1,72 +0,0 @@
//go:build windows
package dnsfw
import (
"net/netip"
"os"
"testing"
)
func TestStrictMode(t *testing.T) {
tests := []struct {
name string
val string
set bool
want bool
}{
{name: "unset", want: false},
{name: "true", val: "true", set: true, want: true},
{name: "1", val: "1", set: true, want: true},
{name: "false", val: "false", set: true, want: false},
{name: "invalid is false", val: "garbage", set: true, want: false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(EnvStrict, tc.val)
if !tc.set {
os.Unsetenv(EnvStrict)
}
if got := strictMode(); got != tc.want {
t.Fatalf("strictMode() = %v, want %v", got, tc.want)
}
})
}
}
func TestWindowsManagerDisableIdempotent(t *testing.T) {
m := &windowsManager{}
if err := m.Disable(); err != nil {
t.Fatalf("first Disable on fresh manager: %v", err)
}
if err := m.Disable(); err != nil {
t.Fatalf("second Disable on fresh manager: %v", err)
}
if m.session != 0 {
t.Fatalf("session should remain zero, got %d", m.session)
}
}
func TestWindowsManagerEnableNoOpWhenDisabledByEnv(t *testing.T) {
t.Setenv(EnvDisable, "true")
m := &windowsManager{}
if err := m.Enable("00000000-0000-0000-0000-000000000000", netip.Addr{}); err != nil {
t.Fatalf("Enable should be a no-op when firewall disabled by env: %v", err)
}
if m.session != 0 {
t.Fatalf("session must remain zero when env disables firewall, got %d", m.session)
}
}
func TestWindowsManagerEnableNoOpWhenPortsEmpty(t *testing.T) {
t.Setenv(EnvPorts, "")
m := &windowsManager{}
if err := m.Enable("00000000-0000-0000-0000-000000000000", netip.Addr{}); err != nil {
t.Fatalf("Enable should be a no-op when ports list is empty: %v", err)
}
if m.session != 0 {
t.Fatalf("session must remain zero when ports list is empty, got %d", m.session)
}
}

View File

@@ -1,53 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*
* Adapted from wireguard-windows tunnel/firewall/helpers.go.
*/
package dnsfw
import (
"errors"
"fmt"
"runtime"
"syscall"
"golang.org/x/sys/windows"
)
func createWtFwpmDisplayData0(name, description string) (*wtFwpmDisplayData0, error) {
namePtr, err := windows.UTF16PtrFromString(name)
if err != nil {
return nil, wrapErr(err)
}
descriptionPtr, err := windows.UTF16PtrFromString(description)
if err != nil {
return nil, wrapErr(err)
}
return &wtFwpmDisplayData0{
name: namePtr,
description: descriptionPtr,
}, nil
}
func filterWeight(weight uint8) wtFwpValue0 {
return wtFwpValue0{
_type: cFWP_UINT8,
value: uintptr(weight),
}
}
func wrapErr(err error) error {
var errno syscall.Errno
if !errors.As(err, &errno) {
return err
}
_, file, line, ok := runtime.Caller(1)
if !ok {
return fmt.Errorf("wfp error at unknown location: %w", err)
}
return fmt.Errorf("wfp error at %s:%d: %w", file, line, err)
}

View File

@@ -1,249 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2026 NetBird GmbH. All Rights Reserved.
*
* Filter installers adapted from wireguard-windows tunnel/firewall/rules.go.
* The block-DNS approach (port 53 + UDP/TCP) matches what wireguard-windows
* uses for its kill-switch DNS leak protection. We extend it with a
* configurable port set so we also cover :853 (DoT) and any future ports.
*/
package dnsfw
import (
"encoding/binary"
"fmt"
"net/netip"
"unsafe"
"github.com/hashicorp/go-multierror"
"golang.org/x/sys/windows"
nberrors "github.com/netbirdio/netbird/client/errors"
)
// Filters install at outbound ALE_AUTH_CONNECT layers only; inbound replies
// follow the authorized outbound flow.
// permitTunInterface installs a permit filter for any traffic whose local
// interface is the netbird tunnel.
func permitTunInterface(session uintptr, base *baseObjects, weight uint8, ifLUID uint64) error {
cond := wtFwpmFilterCondition0{
fieldKey: cFWPM_CONDITION_IP_LOCAL_INTERFACE,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_UINT64,
value: uintptr(unsafe.Pointer(&ifLUID)),
},
}
filter := wtFwpmFilter0{
providerKey: &base.provider,
subLayerKey: base.filters,
weight: filterWeight(weight),
numFilterConditions: 1,
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&cond)),
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
}
return addOutboundFilters(session, &filter, "Permit netbird tunnel")
}
// permitDaemonByAppID installs a permit filter matching the netbird daemon
// executable by App-ID. App-ID alone is sufficient because netbird.exe is a
// dedicated binary.
func permitDaemonByAppID(session uintptr, base *baseObjects, daemonExe string, weight uint8) error {
appID, err := daemonAppID(daemonExe)
if err != nil {
return err
}
defer fwpmFreeMemory0(unsafe.Pointer(&appID))
cond := wtFwpmFilterCondition0{
fieldKey: cFWPM_CONDITION_ALE_APP_ID,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_BYTE_BLOB_TYPE,
value: uintptr(unsafe.Pointer(appID)),
},
}
filter := wtFwpmFilter0{
providerKey: &base.provider,
subLayerKey: base.filters,
weight: filterWeight(weight),
numFilterConditions: 1,
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&cond)),
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
}
return addOutboundFilters(session, &filter, "Permit netbird daemon")
}
// permitVirtualDNSIP installs a permit filter for DNS-port traffic destined
// for the in-tunnel virtual DNS IP. Used in strict mode in lieu of
// permitTunInterface.
func permitVirtualDNSIP(session uintptr, base *baseObjects, ip netip.Addr, ports []uint16, weight uint8) error {
var merr *multierror.Error
for _, port := range ports {
if err := permitDNSToHost(session, base, ip, port, weight); err != nil {
merr = multierror.Append(merr, fmt.Errorf("permit %s:%d: %w", ip, port, err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func permitDNSToHost(session uintptr, base *baseObjects, ip netip.Addr, port uint16, weight uint8) error {
if !ip.IsValid() {
return fmt.Errorf("invalid address")
}
var addrCond wtFwpmFilterCondition0
var layer windows.GUID
// v6 backing must outlive fwpmFilterAdd0; keep it on this stack frame.
var v6 wtFwpByteArray16
if ip.Is4() {
v4 := ip.As4()
addrCond = wtFwpmFilterCondition0{
fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_UINT32,
value: uintptr(binary.BigEndian.Uint32(v4[:])),
},
}
layer = cFWPM_LAYER_ALE_AUTH_CONNECT_V4
} else {
v6 = wtFwpByteArray16{byteArray16: ip.As16()}
addrCond = wtFwpmFilterCondition0{
fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_BYTE_ARRAY16_TYPE,
value: uintptr(unsafe.Pointer(&v6)),
},
}
layer = cFWPM_LAYER_ALE_AUTH_CONNECT_V6
}
conditions := [2]wtFwpmFilterCondition0{
addrCond,
{
fieldKey: cFWPM_CONDITION_IP_REMOTE_PORT,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_UINT16,
value: uintptr(port),
},
},
}
filter := wtFwpmFilter0{
providerKey: &base.provider,
subLayerKey: base.filters,
weight: filterWeight(weight),
numFilterConditions: uint32(len(conditions)),
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&conditions[0])),
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
}
display, err := createWtFwpmDisplayData0(fmt.Sprintf("Permit DNS to %s:%d", ip, port), "")
if err != nil {
return wrapErr(err)
}
filter.displayData = *display
filter.layerKey = layer
var filterID uint64
if err := fwpmFilterAdd0(session, &filter, 0, &filterID); err != nil {
return wrapErr(err)
}
_ = v6
return nil
}
// blockDNSPorts installs a deny filter for outbound traffic to each of the
// given remote ports over UDP or TCP. Per-port and per-layer failures are
// accumulated; partial coverage is preferred over zero coverage.
func blockDNSPorts(session uintptr, base *baseObjects, ports []uint16, weight uint8) error {
var merr *multierror.Error
for _, port := range ports {
if err := blockDNSPort(session, base, port, weight); err != nil {
merr = multierror.Append(merr, fmt.Errorf("block port %d: %w", port, err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func blockDNSPort(session uintptr, base *baseObjects, port uint16, weight uint8) error {
conditions := [3]wtFwpmFilterCondition0{
{
fieldKey: cFWPM_CONDITION_IP_REMOTE_PORT,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_UINT16,
value: uintptr(port),
},
},
{
fieldKey: cFWPM_CONDITION_IP_PROTOCOL,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_UINT8,
value: uintptr(cIPPROTO_UDP),
},
},
// Repeat the IP_PROTOCOL condition for logical OR with TCP.
{
fieldKey: cFWPM_CONDITION_IP_PROTOCOL,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_UINT8,
value: uintptr(cIPPROTO_TCP),
},
},
}
filter := wtFwpmFilter0{
providerKey: &base.provider,
subLayerKey: base.filters,
weight: filterWeight(weight),
numFilterConditions: uint32(len(conditions)),
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&conditions[0])),
action: wtFwpmAction0{_type: cFWP_ACTION_BLOCK},
}
return addOutboundFilters(session, &filter, fmt.Sprintf("Block DNS port %d", port))
}
// addOutboundFilters installs the same filter on the v4 and v6 outbound ALE
// connect layers. v4 and v6 are installed independently: failure on one
// layer does not abort the other, and the accumulated errors are returned.
// Partial coverage is preferred over zero coverage.
func addOutboundFilters(session uintptr, filter *wtFwpmFilter0, name string) error {
layers := [...]struct {
layer windows.GUID
label string
}{
{cFWPM_LAYER_ALE_AUTH_CONNECT_V4, name + " (IPv4)"},
{cFWPM_LAYER_ALE_AUTH_CONNECT_V6, name + " (IPv6)"},
}
var merr *multierror.Error
for _, l := range layers {
display, err := createWtFwpmDisplayData0(l.label, "")
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("%s: %w", l.label, wrapErr(err)))
continue
}
filter.displayData = *display
filter.layerKey = l.layer
var filterID uint64
if err := fwpmFilterAdd0(session, filter, 0, &filterID); err != nil {
merr = multierror.Append(merr, fmt.Errorf("%s: %w", l.label, wrapErr(err)))
}
}
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -1,177 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2026 NetBird GmbH. All Rights Reserved.
*
* Session lifecycle and the high-level Install/Close entry points adapted
* from wireguard-windows tunnel/firewall.
*/
package dnsfw
import (
"errors"
"fmt"
"net/netip"
"unsafe"
"github.com/hashicorp/go-multierror"
"golang.org/x/sys/windows"
nberrors "github.com/netbirdio/netbird/client/errors"
)
// installConfig is the input to installFilters.
type installConfig struct {
tunLUID uint64
daemonExe string
blockedPorts []uint16
// strict, when true, narrows the carve-out from "anything on tun" to
// "DNS only to virtualDNSIP". virtualDNSIP must be valid in this case.
strict bool
virtualDNSIP netip.Addr
}
// baseObjects holds the GUIDs of the WFP provider and sublayer registered
// for our session. Both are randomly generated per session.
type baseObjects struct {
provider windows.GUID
filters windows.GUID
}
// installFilters opens a dynamic WFP session and installs the netbird DNS
// firewall filters. Returns a zero session on hard failure (session create,
// base objects); a non-zero session with a non-nil error is a partial install
// (some per-filter installs failed) and is safe to close.
func installFilters(cfg installConfig) (session uintptr, err error) {
defer func() {
if r := recover(); r != nil {
// Dynamic session: kernel will clean up on process exit even
// if we leave the handle dangling here.
err = fmt.Errorf("panic in installFilters: %v", r)
}
}()
if len(cfg.blockedPorts) == 0 {
return 0, errors.New("dns firewall: no blocked ports configured")
}
if cfg.strict && !cfg.virtualDNSIP.IsValid() {
return 0, errors.New("dns firewall: strict mode requires a valid virtual DNS IP")
}
session, err = createSession()
if err != nil {
return 0, err
}
base, err := registerBaseObjects(session)
if err != nil {
_ = fwpmEngineClose0(session)
return 0, fmt.Errorf("register base objects: %w", err)
}
var merr *multierror.Error
if cfg.strict {
if err := permitVirtualDNSIP(session, base, cfg.virtualDNSIP, cfg.blockedPorts, 15); err != nil {
merr = multierror.Append(merr, fmt.Errorf("permit virtual dns: %w", err))
}
} else {
if err := permitTunInterface(session, base, 15, cfg.tunLUID); err != nil {
merr = multierror.Append(merr, fmt.Errorf("permit tun interface: %w", err))
}
}
if err := permitDaemonByAppID(session, base, cfg.daemonExe, 14); err != nil {
merr = multierror.Append(merr, fmt.Errorf("permit netbird daemon: %w", err))
}
if err := blockDNSPorts(session, base, cfg.blockedPorts, 10); err != nil {
merr = multierror.Append(merr, fmt.Errorf("block dns ports: %w", err))
}
return session, nberrors.FormatErrorOrNil(merr)
}
// closeSession tears down a WFP session previously opened by installFilters.
// All filters owned by the session are removed.
func closeSession(session uintptr) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in closeSession: %v", r)
}
}()
if session == 0 {
return nil
}
if err := fwpmEngineClose0(session); err != nil {
return wrapErr(err)
}
return nil
}
func createSession() (uintptr, error) {
displayData, err := createWtFwpmDisplayData0("NetBird DNS firewall", "NetBird DNS firewall dynamic session")
if err != nil {
return 0, wrapErr(err)
}
session := wtFwpmSession0{
displayData: *displayData,
flags: cFWPM_SESSION_FLAG_DYNAMIC,
txnWaitTimeoutInMSec: windows.INFINITE,
}
var handle uintptr
if err := fwpmEngineOpen0(nil, cRPC_C_AUTHN_WINNT, nil, &session, unsafe.Pointer(&handle)); err != nil {
return 0, wrapErr(err)
}
return handle, nil
}
func registerBaseObjects(session uintptr) (*baseObjects, error) {
bo := &baseObjects{}
var err error
if bo.provider, err = windows.GenerateGUID(); err != nil {
return nil, wrapErr(err)
}
if bo.filters, err = windows.GenerateGUID(); err != nil {
return nil, wrapErr(err)
}
displayData, err := createWtFwpmDisplayData0("NetBird DNS firewall", "NetBird DNS firewall provider")
if err != nil {
return nil, wrapErr(err)
}
provider := wtFwpmProvider0{
providerKey: bo.provider,
displayData: *displayData,
}
if err := fwpmProviderAdd0(session, &provider, 0); err != nil {
return nil, wrapErr(err)
}
subDisplay, err := createWtFwpmDisplayData0("NetBird DNS firewall filters", "Permit and block filters")
if err != nil {
return nil, wrapErr(err)
}
sublayer := wtFwpmSublayer0{
subLayerKey: bo.filters,
displayData: *subDisplay,
providerKey: &bo.provider,
weight: ^uint16(0),
}
if err := fwpmSubLayerAdd0(session, &sublayer, 0); err != nil {
return nil, wrapErr(err)
}
return bo, nil
}
// daemonAppID returns the WFP App-ID byte blob for the given executable path.
func daemonAppID(path string) (*wtFwpByteBlob, error) {
pathPtr, err := windows.UTF16PtrFromString(path)
if err != nil {
return nil, wrapErr(err)
}
var appID *wtFwpByteBlob
if err := fwpmGetAppIdFromFileName0(pathPtr, unsafe.Pointer(&appID)); err != nil {
return nil, wrapErr(err)
}
return appID, nil
}

View File

@@ -1,38 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*
* Adapted from wireguard-windows tunnel/firewall/syscall_windows.go.
*/
package dnsfw
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmengineopen0
//sys fwpmEngineOpen0(serverName *uint16, authnService wtRpcCAuthN, authIdentity *uintptr, session *wtFwpmSession0, engineHandle unsafe.Pointer) (err error) [failretval!=0] = fwpuclnt.FwpmEngineOpen0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmengineclose0
//sys fwpmEngineClose0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmEngineClose0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmsublayeradd0
//sys fwpmSubLayerAdd0(engineHandle uintptr, subLayer *wtFwpmSublayer0, sd uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmSubLayerAdd0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmgetappidfromfilename0
//sys fwpmGetAppIdFromFileName0(fileName *uint16, appID unsafe.Pointer) (err error) [failretval!=0] = fwpuclnt.FwpmGetAppIdFromFileName0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmfreememory0
//sys fwpmFreeMemory0(p unsafe.Pointer) = fwpuclnt.FwpmFreeMemory0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmfilteradd0
//sys fwpmFilterAdd0(engineHandle uintptr, filter *wtFwpmFilter0, sd uintptr, id *uint64) (err error) [failretval!=0] = fwpuclnt.FwpmFilterAdd0
// https://docs.microsoft.com/en-us/windows/desktop/api/Fwpmu/nf-fwpmu-fwpmtransactionbegin0
//sys fwpmTransactionBegin0(engineHandle uintptr, flags uint32) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionBegin0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmtransactioncommit0
//sys fwpmTransactionCommit0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionCommit0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmtransactionabort0
//sys fwpmTransactionAbort0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionAbort0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmprovideradd0
//sys fwpmProviderAdd0(engineHandle uintptr, provider *wtFwpmProvider0, sd uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmProviderAdd0

View File

@@ -1,414 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*
* Adapted from wireguard-windows tunnel/firewall/types_windows.go.
*/
package dnsfw
import "golang.org/x/sys/windows"
const (
anysizeArray = 1 // ANYSIZE_ARRAY defined in winnt.h
wtFwpBitmapArray64_Size = 8
wtFwpByteArray16_Size = 16
wtFwpByteArray6_Size = 6
wtFwpmAction0_Size = 20
wtFwpmAction0_filterType_Offset = 4
wtFwpV4AddrAndMask_Size = 8
wtFwpV4AddrAndMask_mask_Offset = 4
wtFwpV6AddrAndMask_Size = 17
wtFwpV6AddrAndMask_prefixLength_Offset = 16
)
type wtFwpActionFlag uint32
const (
cFWP_ACTION_FLAG_TERMINATING wtFwpActionFlag = 0x00001000
cFWP_ACTION_FLAG_NON_TERMINATING wtFwpActionFlag = 0x00002000
cFWP_ACTION_FLAG_CALLOUT wtFwpActionFlag = 0x00004000
)
// FWP_ACTION_TYPE defined in fwptypes.h
type wtFwpActionType uint32
const (
cFWP_ACTION_BLOCK wtFwpActionType = wtFwpActionType(0x00000001 | cFWP_ACTION_FLAG_TERMINATING)
cFWP_ACTION_PERMIT wtFwpActionType = wtFwpActionType(0x00000002 | cFWP_ACTION_FLAG_TERMINATING)
cFWP_ACTION_CALLOUT_TERMINATING wtFwpActionType = wtFwpActionType(0x00000003 | cFWP_ACTION_FLAG_CALLOUT | cFWP_ACTION_FLAG_TERMINATING)
cFWP_ACTION_CALLOUT_INSPECTION wtFwpActionType = wtFwpActionType(0x00000004 | cFWP_ACTION_FLAG_CALLOUT | cFWP_ACTION_FLAG_NON_TERMINATING)
cFWP_ACTION_CALLOUT_UNKNOWN wtFwpActionType = wtFwpActionType(0x00000005 | cFWP_ACTION_FLAG_CALLOUT)
cFWP_ACTION_CONTINUE wtFwpActionType = wtFwpActionType(0x00000006 | cFWP_ACTION_FLAG_NON_TERMINATING)
cFWP_ACTION_NONE wtFwpActionType = 0x00000007
cFWP_ACTION_NONE_NO_MATCH wtFwpActionType = 0x00000008
cFWP_ACTION_BITMAP_INDEX_SET wtFwpActionType = 0x00000009
)
// FWP_BYTE_BLOB defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_blob_)
type wtFwpByteBlob struct {
size uint32
data *uint8
}
// FWP_MATCH_TYPE defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ne-fwptypes-fwp_match_type_)
type wtFwpMatchType uint32
const (
cFWP_MATCH_EQUAL wtFwpMatchType = 0
cFWP_MATCH_GREATER wtFwpMatchType = cFWP_MATCH_EQUAL + 1
cFWP_MATCH_LESS wtFwpMatchType = cFWP_MATCH_GREATER + 1
cFWP_MATCH_GREATER_OR_EQUAL wtFwpMatchType = cFWP_MATCH_LESS + 1
cFWP_MATCH_LESS_OR_EQUAL wtFwpMatchType = cFWP_MATCH_GREATER_OR_EQUAL + 1
cFWP_MATCH_RANGE wtFwpMatchType = cFWP_MATCH_LESS_OR_EQUAL + 1
cFWP_MATCH_FLAGS_ALL_SET wtFwpMatchType = cFWP_MATCH_RANGE + 1
cFWP_MATCH_FLAGS_ANY_SET wtFwpMatchType = cFWP_MATCH_FLAGS_ALL_SET + 1
cFWP_MATCH_FLAGS_NONE_SET wtFwpMatchType = cFWP_MATCH_FLAGS_ANY_SET + 1
cFWP_MATCH_EQUAL_CASE_INSENSITIVE wtFwpMatchType = cFWP_MATCH_FLAGS_NONE_SET + 1
cFWP_MATCH_NOT_EQUAL wtFwpMatchType = cFWP_MATCH_EQUAL_CASE_INSENSITIVE + 1
cFWP_MATCH_PREFIX wtFwpMatchType = cFWP_MATCH_NOT_EQUAL + 1
cFWP_MATCH_NOT_PREFIX wtFwpMatchType = cFWP_MATCH_PREFIX + 1
cFWP_MATCH_TYPE_MAX wtFwpMatchType = cFWP_MATCH_NOT_PREFIX + 1
)
// FWPM_ACTION0 defined in fwpmtypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_action0_)
type wtFwpmAction0 struct {
_type wtFwpActionType
filterType windows.GUID // Windows type: GUID
}
// Defined in fwpmu.h. 4cd62a49-59c3-4969-b7f3-bda5d32890a4
var cFWPM_CONDITION_IP_LOCAL_INTERFACE = windows.GUID{
Data1: 0x4cd62a49,
Data2: 0x59c3,
Data3: 0x4969,
Data4: [8]byte{0xb7, 0xf3, 0xbd, 0xa5, 0xd3, 0x28, 0x90, 0xa4},
}
// Defined in fwpmu.h. b235ae9a-1d64-49b8-a44c-5ff3d9095045
var cFWPM_CONDITION_IP_REMOTE_ADDRESS = windows.GUID{
Data1: 0xb235ae9a,
Data2: 0x1d64,
Data3: 0x49b8,
Data4: [8]byte{0xa4, 0x4c, 0x5f, 0xf3, 0xd9, 0x09, 0x50, 0x45},
}
// Defined in fwpmu.h. 3971ef2b-623e-4f9a-8cb1-6e79b806b9a7
var cFWPM_CONDITION_IP_PROTOCOL = windows.GUID{
Data1: 0x3971ef2b,
Data2: 0x623e,
Data3: 0x4f9a,
Data4: [8]byte{0x8c, 0xb1, 0x6e, 0x79, 0xb8, 0x06, 0xb9, 0xa7},
}
// Defined in fwpmu.h. 0c1ba1af-5765-453f-af22-a8f791ac775b
var cFWPM_CONDITION_IP_LOCAL_PORT = windows.GUID{
Data1: 0x0c1ba1af,
Data2: 0x5765,
Data3: 0x453f,
Data4: [8]byte{0xaf, 0x22, 0xa8, 0xf7, 0x91, 0xac, 0x77, 0x5b},
}
// Defined in fwpmu.h. c35a604d-d22b-4e1a-91b4-68f674ee674b
var cFWPM_CONDITION_IP_REMOTE_PORT = windows.GUID{
Data1: 0xc35a604d,
Data2: 0xd22b,
Data3: 0x4e1a,
Data4: [8]byte{0x91, 0xb4, 0x68, 0xf6, 0x74, 0xee, 0x67, 0x4b},
}
// Defined in fwpmu.h. d78e1e87-8644-4ea5-9437-d809ecefc971
var cFWPM_CONDITION_ALE_APP_ID = windows.GUID{
Data1: 0xd78e1e87,
Data2: 0x8644,
Data3: 0x4ea5,
Data4: [8]byte{0x94, 0x37, 0xd8, 0x09, 0xec, 0xef, 0xc9, 0x71},
}
// af043a0a-b34d-4f86-979c-c90371af6e66
var cFWPM_CONDITION_ALE_USER_ID = windows.GUID{
Data1: 0xaf043a0a,
Data2: 0xb34d,
Data3: 0x4f86,
Data4: [8]byte{0x97, 0x9c, 0xc9, 0x03, 0x71, 0xaf, 0x6e, 0x66},
}
// d9ee00de-c1ef-4617-bfe3-ffd8f5a08957
var cFWPM_CONDITION_IP_LOCAL_ADDRESS = windows.GUID{
Data1: 0xd9ee00de,
Data2: 0xc1ef,
Data3: 0x4617,
Data4: [8]byte{0xbf, 0xe3, 0xff, 0xd8, 0xf5, 0xa0, 0x89, 0x57},
}
var (
cFWPM_CONDITION_ICMP_TYPE = cFWPM_CONDITION_IP_LOCAL_PORT
cFWPM_CONDITION_ICMP_CODE = cFWPM_CONDITION_IP_REMOTE_PORT
)
// 7bc43cbf-37ba-45f1-b74a-82ff518eeb10
var cFWPM_CONDITION_L2_FLAGS = windows.GUID{
Data1: 0x7bc43cbf,
Data2: 0x37ba,
Data3: 0x45f1,
Data4: [8]byte{0xb7, 0x4a, 0x82, 0xff, 0x51, 0x8e, 0xeb, 0x10},
}
type wtFwpmL2Flags uint32
const cFWP_CONDITION_L2_IS_VM2VM wtFwpmL2Flags = 0x00000010
var cFWPM_CONDITION_FLAGS = windows.GUID{
Data1: 0x632ce23b,
Data2: 0x5167,
Data3: 0x435c,
Data4: [8]byte{0x86, 0xd7, 0xe9, 0x03, 0x68, 0x4a, 0xa8, 0x0c},
}
type wtFwpmFlags uint32
const cFWP_CONDITION_FLAG_IS_LOOPBACK wtFwpmFlags = 0x00000001
// Defined in fwpmtypes.h
type wtFwpmFilterFlags uint32
const (
cFWPM_FILTER_FLAG_NONE wtFwpmFilterFlags = 0x00000000
cFWPM_FILTER_FLAG_PERSISTENT wtFwpmFilterFlags = 0x00000001
cFWPM_FILTER_FLAG_BOOTTIME wtFwpmFilterFlags = 0x00000002
cFWPM_FILTER_FLAG_HAS_PROVIDER_CONTEXT wtFwpmFilterFlags = 0x00000004
cFWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT wtFwpmFilterFlags = 0x00000008
cFWPM_FILTER_FLAG_PERMIT_IF_CALLOUT_UNREGISTERED wtFwpmFilterFlags = 0x00000010
cFWPM_FILTER_FLAG_DISABLED wtFwpmFilterFlags = 0x00000020
cFWPM_FILTER_FLAG_INDEXED wtFwpmFilterFlags = 0x00000040
cFWPM_FILTER_FLAG_HAS_SECURITY_REALM_PROVIDER_CONTEXT wtFwpmFilterFlags = 0x00000080
cFWPM_FILTER_FLAG_SYSTEMOS_ONLY wtFwpmFilterFlags = 0x00000100
cFWPM_FILTER_FLAG_GAMEOS_ONLY wtFwpmFilterFlags = 0x00000200
cFWPM_FILTER_FLAG_SILENT_MODE wtFwpmFilterFlags = 0x00000400
cFWPM_FILTER_FLAG_IPSEC_NO_ACQUIRE_INITIATE wtFwpmFilterFlags = 0x00000800
)
// FWPM_LAYER_ALE_AUTH_CONNECT_V4 (c38d57d1-05a7-4c33-904f-7fbceee60e82) defined in fwpmu.h
var cFWPM_LAYER_ALE_AUTH_CONNECT_V4 = windows.GUID{
Data1: 0xc38d57d1,
Data2: 0x05a7,
Data3: 0x4c33,
Data4: [8]byte{0x90, 0x4f, 0x7f, 0xbc, 0xee, 0xe6, 0x0e, 0x82},
}
// e1cd9fe7-f4b5-4273-96c0-592e487b8650
var cFWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 = windows.GUID{
Data1: 0xe1cd9fe7,
Data2: 0xf4b5,
Data3: 0x4273,
Data4: [8]byte{0x96, 0xc0, 0x59, 0x2e, 0x48, 0x7b, 0x86, 0x50},
}
// FWPM_LAYER_ALE_AUTH_CONNECT_V6 (4a72393b-319f-44bc-84c3-ba54dcb3b6b4) defined in fwpmu.h
var cFWPM_LAYER_ALE_AUTH_CONNECT_V6 = windows.GUID{
Data1: 0x4a72393b,
Data2: 0x319f,
Data3: 0x44bc,
Data4: [8]byte{0x84, 0xc3, 0xba, 0x54, 0xdc, 0xb3, 0xb6, 0xb4},
}
// a3b42c97-9f04-4672-b87e-cee9c483257f
var cFWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6 = windows.GUID{
Data1: 0xa3b42c97,
Data2: 0x9f04,
Data3: 0x4672,
Data4: [8]byte{0xb8, 0x7e, 0xce, 0xe9, 0xc4, 0x83, 0x25, 0x7f},
}
// 94c44912-9d6f-4ebf-b995-05ab8a088d1b
var cFWPM_LAYER_OUTBOUND_MAC_FRAME_NATIVE = windows.GUID{
Data1: 0x94c44912,
Data2: 0x9d6f,
Data3: 0x4ebf,
Data4: [8]byte{0xb9, 0x95, 0x05, 0xab, 0x8a, 0x08, 0x8d, 0x1b},
}
// d4220bd3-62ce-4f08-ae88-b56e8526df50
var cFWPM_LAYER_INBOUND_MAC_FRAME_NATIVE = windows.GUID{
Data1: 0xd4220bd3,
Data2: 0x62ce,
Data3: 0x4f08,
Data4: [8]byte{0xae, 0x88, 0xb5, 0x6e, 0x85, 0x26, 0xdf, 0x50},
}
// FWP_BITMAP_ARRAY64 defined in fwtypes.h
type wtFwpBitmapArray64 struct {
bitmapArray64 [8]uint8 // Windows type: [8]UINT8
}
// FWP_BYTE_ARRAY6 defined in fwtypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_array6_)
type wtFwpByteArray6 struct {
byteArray6 [6]uint8 // Windows type: [6]UINT8
}
// FWP_BYTE_ARRAY16 defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_array16_)
type wtFwpByteArray16 struct {
byteArray16 [16]uint8 // Windows type [16]UINT8
}
// FWP_CONDITION_VALUE0 defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_condition_value0).
type wtFwpConditionValue0 wtFwpValue0
// FWP_DATA_TYPE defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ne-fwptypes-fwp_data_type_)
type wtFwpDataType uint
const (
cFWP_EMPTY wtFwpDataType = 0
cFWP_UINT8 wtFwpDataType = cFWP_EMPTY + 1
cFWP_UINT16 wtFwpDataType = cFWP_UINT8 + 1
cFWP_UINT32 wtFwpDataType = cFWP_UINT16 + 1
cFWP_UINT64 wtFwpDataType = cFWP_UINT32 + 1
cFWP_INT8 wtFwpDataType = cFWP_UINT64 + 1
cFWP_INT16 wtFwpDataType = cFWP_INT8 + 1
cFWP_INT32 wtFwpDataType = cFWP_INT16 + 1
cFWP_INT64 wtFwpDataType = cFWP_INT32 + 1
cFWP_FLOAT wtFwpDataType = cFWP_INT64 + 1
cFWP_DOUBLE wtFwpDataType = cFWP_FLOAT + 1
cFWP_BYTE_ARRAY16_TYPE wtFwpDataType = cFWP_DOUBLE + 1
cFWP_BYTE_BLOB_TYPE wtFwpDataType = cFWP_BYTE_ARRAY16_TYPE + 1
cFWP_SID wtFwpDataType = cFWP_BYTE_BLOB_TYPE + 1
cFWP_SECURITY_DESCRIPTOR_TYPE wtFwpDataType = cFWP_SID + 1
cFWP_TOKEN_INFORMATION_TYPE wtFwpDataType = cFWP_SECURITY_DESCRIPTOR_TYPE + 1
cFWP_TOKEN_ACCESS_INFORMATION_TYPE wtFwpDataType = cFWP_TOKEN_INFORMATION_TYPE + 1
cFWP_UNICODE_STRING_TYPE wtFwpDataType = cFWP_TOKEN_ACCESS_INFORMATION_TYPE + 1
cFWP_BYTE_ARRAY6_TYPE wtFwpDataType = cFWP_UNICODE_STRING_TYPE + 1
cFWP_BITMAP_INDEX_TYPE wtFwpDataType = cFWP_BYTE_ARRAY6_TYPE + 1
cFWP_BITMAP_ARRAY64_TYPE wtFwpDataType = cFWP_BITMAP_INDEX_TYPE + 1
cFWP_SINGLE_DATA_TYPE_MAX wtFwpDataType = 0xff
cFWP_V4_ADDR_MASK wtFwpDataType = cFWP_SINGLE_DATA_TYPE_MAX + 1
cFWP_V6_ADDR_MASK wtFwpDataType = cFWP_V4_ADDR_MASK + 1
cFWP_RANGE_TYPE wtFwpDataType = cFWP_V6_ADDR_MASK + 1
cFWP_DATA_TYPE_MAX wtFwpDataType = cFWP_RANGE_TYPE + 1
)
// FWP_V4_ADDR_AND_MASK defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_v4_addr_and_mask).
type wtFwpV4AddrAndMask struct {
addr uint32
mask uint32
}
// FWP_V6_ADDR_AND_MASK defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_v6_addr_and_mask).
type wtFwpV6AddrAndMask struct {
addr [16]uint8
prefixLength uint8
}
// FWP_VALUE0 defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_value0_)
type wtFwpValue0 struct {
_type wtFwpDataType
value uintptr
}
// FWPM_DISPLAY_DATA0 defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwpm_display_data0).
type wtFwpmDisplayData0 struct {
name *uint16 // Windows type: *wchar_t
description *uint16 // Windows type: *wchar_t
}
// FWPM_FILTER_CONDITION0 defined in fwpmtypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter_condition0).
type wtFwpmFilterCondition0 struct {
fieldKey windows.GUID // Windows type: GUID
matchType wtFwpMatchType
conditionValue wtFwpConditionValue0
}
// FWPM_PROVIDER0 defined in fwpmtypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_provider0_)
type wtFwpProvider0 struct {
providerKey windows.GUID // Windows type: GUID
displayData wtFwpmDisplayData0
flags uint32
providerData wtFwpByteBlob
serviceName *uint16 // Windows type: *wchar_t
}
type wtFwpmSessionFlagsValue uint32
const (
cFWPM_SESSION_FLAG_DYNAMIC wtFwpmSessionFlagsValue = 0x00000001 // FWPM_SESSION_FLAG_DYNAMIC defined in fwpmtypes.h
)
// FWPM_SESSION0 defined in fwpmtypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_session0).
type wtFwpmSession0 struct {
sessionKey windows.GUID // Windows type: GUID
displayData wtFwpmDisplayData0
flags wtFwpmSessionFlagsValue // Windows type UINT32
txnWaitTimeoutInMSec uint32
processId uint32 // Windows type: DWORD
sid *windows.SID
username *uint16 // Windows type: *wchar_t
kernelMode uint8 // Windows type: BOOL
}
type wtFwpmSublayerFlags uint32
const (
cFWPM_SUBLAYER_FLAG_PERSISTENT wtFwpmSublayerFlags = 0x00000001 // FWPM_SUBLAYER_FLAG_PERSISTENT defined in fwpmtypes.h
)
// FWPM_SUBLAYER0 defined in fwpmtypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_sublayer0_)
type wtFwpmSublayer0 struct {
subLayerKey windows.GUID // Windows type: GUID
displayData wtFwpmDisplayData0
flags wtFwpmSublayerFlags
providerKey *windows.GUID // Windows type: *GUID
providerData wtFwpByteBlob
weight uint16
}
// Defined in rpcdce.h
type wtRpcCAuthN uint32
const (
cRPC_C_AUTHN_NONE wtRpcCAuthN = 0
cRPC_C_AUTHN_WINNT wtRpcCAuthN = 10
cRPC_C_AUTHN_DEFAULT wtRpcCAuthN = 0xFFFFFFFF
)
// FWPM_PROVIDER0 defined in fwpmtypes.h
// (https://docs.microsoft.com/sv-se/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_provider0).
type wtFwpmProvider0 struct {
providerKey windows.GUID
displayData wtFwpmDisplayData0
flags uint32
providerData wtFwpByteBlob
serviceName *uint16
}
type wtIPProto uint32
const (
cIPPROTO_ICMP wtIPProto = 1
cIPPROTO_ICMPV6 wtIPProto = 58
cIPPROTO_TCP wtIPProto = 6
cIPPROTO_UDP wtIPProto = 17
)
const (
cFWP_ACTRL_MATCH_FILTER = 1
)

View File

@@ -1,92 +0,0 @@
//go:build windows && (386 || arm)
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*
* Adapted from wireguard-windows tunnel/firewall/types_windows_32.go.
*/
package dnsfw
import "golang.org/x/sys/windows"
const (
wtFwpByteBlob_Size = 8
wtFwpByteBlob_data_Offset = 4
wtFwpConditionValue0_Size = 8
wtFwpConditionValue0_uint8_Offset = 4
wtFwpmDisplayData0_Size = 8
wtFwpmDisplayData0_description_Offset = 4
wtFwpmFilter0_Size = 152
wtFwpmFilter0_displayData_Offset = 16
wtFwpmFilter0_flags_Offset = 24
wtFwpmFilter0_providerKey_Offset = 28
wtFwpmFilter0_providerData_Offset = 32
wtFwpmFilter0_layerKey_Offset = 40
wtFwpmFilter0_subLayerKey_Offset = 56
wtFwpmFilter0_weight_Offset = 72
wtFwpmFilter0_numFilterConditions_Offset = 80
wtFwpmFilter0_filterCondition_Offset = 84
wtFwpmFilter0_action_Offset = 88
wtFwpmFilter0_providerContextKey_Offset = 112
wtFwpmFilter0_reserved_Offset = 128
wtFwpmFilter0_filterID_Offset = 136
wtFwpmFilter0_effectiveWeight_Offset = 144
wtFwpmFilterCondition0_Size = 28
wtFwpmFilterCondition0_matchType_Offset = 16
wtFwpmFilterCondition0_conditionValue_Offset = 20
wtFwpmSession0_Size = 48
wtFwpmSession0_displayData_Offset = 16
wtFwpmSession0_flags_Offset = 24
wtFwpmSession0_txnWaitTimeoutInMSec_Offset = 28
wtFwpmSession0_processId_Offset = 32
wtFwpmSession0_sid_Offset = 36
wtFwpmSession0_username_Offset = 40
wtFwpmSession0_kernelMode_Offset = 44
wtFwpmSublayer0_Size = 44
wtFwpmSublayer0_displayData_Offset = 16
wtFwpmSublayer0_flags_Offset = 24
wtFwpmSublayer0_providerKey_Offset = 28
wtFwpmSublayer0_providerData_Offset = 32
wtFwpmSublayer0_weight_Offset = 40
wtFwpProvider0_Size = 40
wtFwpProvider0_displayData_Offset = 16
wtFwpProvider0_flags_Offset = 24
wtFwpProvider0_providerData_Offset = 28
wtFwpProvider0_serviceName_Offset = 36
wtFwpTokenInformation_Size = 16
wtFwpValue0_Size = 8
wtFwpValue0_value_Offset = 4
)
// FWPM_FILTER0 defined in fwpmtypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter0).
type wtFwpmFilter0 struct {
filterKey windows.GUID // Windows type: GUID
displayData wtFwpmDisplayData0
flags wtFwpmFilterFlags
providerKey *windows.GUID // Windows type: *GUID
providerData wtFwpByteBlob
layerKey windows.GUID // Windows type: GUID
subLayerKey windows.GUID // Windows type: GUID
weight wtFwpValue0
numFilterConditions uint32
filterCondition *wtFwpmFilterCondition0
action wtFwpmAction0
offset1 [4]byte // Layout correction field
providerContextKey windows.GUID // Windows type: GUID
reserved *windows.GUID // Windows type: *GUID
offset2 [4]byte // Layout correction field
filterID uint64
effectiveWeight wtFwpValue0
}

View File

@@ -1,89 +0,0 @@
//go:build windows && (amd64 || arm64)
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*
* Adapted from wireguard-windows tunnel/firewall/types_windows_64.go.
*/
package dnsfw
import "golang.org/x/sys/windows"
const (
wtFwpByteBlob_Size = 16
wtFwpByteBlob_data_Offset = 8
wtFwpConditionValue0_Size = 16
wtFwpConditionValue0_uint8_Offset = 8
wtFwpmDisplayData0_Size = 16
wtFwpmDisplayData0_description_Offset = 8
wtFwpmFilter0_Size = 200
wtFwpmFilter0_displayData_Offset = 16
wtFwpmFilter0_flags_Offset = 32
wtFwpmFilter0_providerKey_Offset = 40
wtFwpmFilter0_providerData_Offset = 48
wtFwpmFilter0_layerKey_Offset = 64
wtFwpmFilter0_subLayerKey_Offset = 80
wtFwpmFilter0_weight_Offset = 96
wtFwpmFilter0_numFilterConditions_Offset = 112
wtFwpmFilter0_filterCondition_Offset = 120
wtFwpmFilter0_action_Offset = 128
wtFwpmFilter0_providerContextKey_Offset = 152
wtFwpmFilter0_reserved_Offset = 168
wtFwpmFilter0_filterID_Offset = 176
wtFwpmFilter0_effectiveWeight_Offset = 184
wtFwpmFilterCondition0_Size = 40
wtFwpmFilterCondition0_matchType_Offset = 16
wtFwpmFilterCondition0_conditionValue_Offset = 24
wtFwpmSession0_Size = 72
wtFwpmSession0_displayData_Offset = 16
wtFwpmSession0_flags_Offset = 32
wtFwpmSession0_txnWaitTimeoutInMSec_Offset = 36
wtFwpmSession0_processId_Offset = 40
wtFwpmSession0_sid_Offset = 48
wtFwpmSession0_username_Offset = 56
wtFwpmSession0_kernelMode_Offset = 64
wtFwpmSublayer0_Size = 72
wtFwpmSublayer0_displayData_Offset = 16
wtFwpmSublayer0_flags_Offset = 32
wtFwpmSublayer0_providerKey_Offset = 40
wtFwpmSublayer0_providerData_Offset = 48
wtFwpmSublayer0_weight_Offset = 64
wtFwpProvider0_Size = 64
wtFwpProvider0_displayData_Offset = 16
wtFwpProvider0_flags_Offset = 32
wtFwpProvider0_providerData_Offset = 40
wtFwpProvider0_serviceName_Offset = 56
wtFwpValue0_Size = 16
wtFwpValue0_value_Offset = 8
)
// FWPM_FILTER0 defined in fwpmtypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter0).
type wtFwpmFilter0 struct {
filterKey windows.GUID // Windows type: GUID
displayData wtFwpmDisplayData0
flags wtFwpmFilterFlags // Windows type: UINT32
providerKey *windows.GUID // Windows type: *GUID
providerData wtFwpByteBlob
layerKey windows.GUID // Windows type: GUID
subLayerKey windows.GUID // Windows type: GUID
weight wtFwpValue0
numFilterConditions uint32
filterCondition *wtFwpmFilterCondition0
action wtFwpmAction0
offset1 [4]byte // Layout correction field
providerContextKey windows.GUID // Windows type: GUID
reserved *windows.GUID // Windows type: *GUID
filterID uint64
effectiveWeight wtFwpValue0
}

View File

@@ -1,130 +0,0 @@
// Code generated by 'go generate'; DO NOT EDIT.
package dnsfw
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modfwpuclnt = windows.NewLazySystemDLL("fwpuclnt.dll")
procFwpmEngineClose0 = modfwpuclnt.NewProc("FwpmEngineClose0")
procFwpmEngineOpen0 = modfwpuclnt.NewProc("FwpmEngineOpen0")
procFwpmFilterAdd0 = modfwpuclnt.NewProc("FwpmFilterAdd0")
procFwpmFreeMemory0 = modfwpuclnt.NewProc("FwpmFreeMemory0")
procFwpmGetAppIdFromFileName0 = modfwpuclnt.NewProc("FwpmGetAppIdFromFileName0")
procFwpmProviderAdd0 = modfwpuclnt.NewProc("FwpmProviderAdd0")
procFwpmSubLayerAdd0 = modfwpuclnt.NewProc("FwpmSubLayerAdd0")
procFwpmTransactionAbort0 = modfwpuclnt.NewProc("FwpmTransactionAbort0")
procFwpmTransactionBegin0 = modfwpuclnt.NewProc("FwpmTransactionBegin0")
procFwpmTransactionCommit0 = modfwpuclnt.NewProc("FwpmTransactionCommit0")
)
func fwpmEngineClose0(engineHandle uintptr) (err error) {
r1, _, e1 := syscall.Syscall(procFwpmEngineClose0.Addr(), 1, uintptr(engineHandle), 0, 0)
if r1 != 0 {
err = errnoErr(e1)
}
return
}
func fwpmEngineOpen0(serverName *uint16, authnService wtRpcCAuthN, authIdentity *uintptr, session *wtFwpmSession0, engineHandle unsafe.Pointer) (err error) {
r1, _, e1 := syscall.Syscall6(procFwpmEngineOpen0.Addr(), 5, uintptr(unsafe.Pointer(serverName)), uintptr(authnService), uintptr(unsafe.Pointer(authIdentity)), uintptr(unsafe.Pointer(session)), uintptr(engineHandle), 0)
if r1 != 0 {
err = errnoErr(e1)
}
return
}
func fwpmFilterAdd0(engineHandle uintptr, filter *wtFwpmFilter0, sd uintptr, id *uint64) (err error) {
r1, _, e1 := syscall.Syscall6(procFwpmFilterAdd0.Addr(), 4, uintptr(engineHandle), uintptr(unsafe.Pointer(filter)), uintptr(sd), uintptr(unsafe.Pointer(id)), 0, 0)
if r1 != 0 {
err = errnoErr(e1)
}
return
}
func fwpmFreeMemory0(p unsafe.Pointer) {
syscall.Syscall(procFwpmFreeMemory0.Addr(), 1, uintptr(p), 0, 0)
return
}
func fwpmGetAppIdFromFileName0(fileName *uint16, appID unsafe.Pointer) (err error) {
r1, _, e1 := syscall.Syscall(procFwpmGetAppIdFromFileName0.Addr(), 2, uintptr(unsafe.Pointer(fileName)), uintptr(appID), 0)
if r1 != 0 {
err = errnoErr(e1)
}
return
}
func fwpmProviderAdd0(engineHandle uintptr, provider *wtFwpmProvider0, sd uintptr) (err error) {
r1, _, e1 := syscall.Syscall(procFwpmProviderAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(provider)), uintptr(sd))
if r1 != 0 {
err = errnoErr(e1)
}
return
}
func fwpmSubLayerAdd0(engineHandle uintptr, subLayer *wtFwpmSublayer0, sd uintptr) (err error) {
r1, _, e1 := syscall.Syscall(procFwpmSubLayerAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(subLayer)), uintptr(sd))
if r1 != 0 {
err = errnoErr(e1)
}
return
}
func fwpmTransactionAbort0(engineHandle uintptr) (err error) {
r1, _, e1 := syscall.Syscall(procFwpmTransactionAbort0.Addr(), 1, uintptr(engineHandle), 0, 0)
if r1 != 0 {
err = errnoErr(e1)
}
return
}
func fwpmTransactionBegin0(engineHandle uintptr, flags uint32) (err error) {
r1, _, e1 := syscall.Syscall(procFwpmTransactionBegin0.Addr(), 2, uintptr(engineHandle), uintptr(flags), 0)
if r1 != 0 {
err = errnoErr(e1)
}
return
}
func fwpmTransactionCommit0(engineHandle uintptr) (err error) {
r1, _, e1 := syscall.Syscall(procFwpmTransactionCommit0.Addr(), 1, uintptr(engineHandle), 0, 0)
if r1 != 0 {
err = errnoErr(e1)
}
return
}

View File

@@ -16,7 +16,6 @@ import (
"golang.org/x/sys/windows/registry" "golang.org/x/sys/windows/registry"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/dns/dnsfw"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/winregistry" "github.com/netbirdio/netbird/client/internal/winregistry"
) )
@@ -72,7 +71,6 @@ type registryConfigurator struct {
routingAll bool routingAll bool
gpo bool gpo bool
nrptEntryCount int nrptEntryCount int
dnsFirewall dnsfw.Manager
} }
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) { func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
@@ -92,9 +90,8 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
} }
configurator := &registryConfigurator{ configurator := &registryConfigurator{
guid: guid, guid: guid,
gpo: useGPO, gpo: useGPO,
dnsFirewall: dnsfw.New(),
} }
if err := configurator.configureInterface(); err != nil { if err := configurator.configureInterface(); err != nil {
@@ -172,8 +169,16 @@ func (r *registryConfigurator) disableWINSForInterface() error {
} }
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
if err := r.applyRouteAll(config); err != nil { if config.RouteAll {
return err if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
return fmt.Errorf("add dns setup: %w", err)
}
} else if r.routingAll {
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
return fmt.Errorf("delete interface registry key property: %w", err)
}
r.routingAll = false
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
} }
r.updateState(stateManager) r.updateState(stateManager)
@@ -215,35 +220,6 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
return nil return nil
} }
func (r *registryConfigurator) applyRouteAll(config HostDNSConfig) error {
if config.RouteAll {
if err := r.dnsFirewall.Enable(r.guid, config.ServerIP); err != nil {
return fmt.Errorf("dns firewall: %w", err)
}
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
merr := multierror.Append(nil, fmt.Errorf("add dns setup: %w", err))
if dErr := r.dnsFirewall.Disable(); dErr != nil {
merr = multierror.Append(merr, fmt.Errorf("rollback dns firewall: %w", dErr))
}
return nberrors.FormatErrorOrNil(merr)
}
return nil
}
if err := r.dnsFirewall.Disable(); err != nil {
log.Errorf("disable dns firewall: %v", err)
}
if !r.routingAll {
return nil
}
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
return fmt.Errorf("delete interface registry key property: %w", err)
}
r.routingAll = false
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
return nil
}
func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) { func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) {
if err := stateManager.UpdateState(&ShutdownState{ if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid, Guid: r.guid,
@@ -430,10 +406,6 @@ func (r *registryConfigurator) restoreHostDNS() error {
return fmt.Errorf("remove interface registry key: %w", err) return fmt.Errorf("remove interface registry key: %w", err)
} }
if err := r.dnsFirewall.Disable(); err != nil {
log.Errorf("disable dns firewall: %v", err)
}
go r.flushDNSCache() go r.flushDNSCache()
return nil return nil

View File

@@ -8,8 +8,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/sys/windows/registry" "golang.org/x/sys/windows/registry"
"github.com/netbirdio/netbird/client/internal/dns/dnsfw"
) )
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up // TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
@@ -36,9 +34,8 @@ func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
}() }()
cfg := &registryConfigurator{ cfg := &registryConfigurator{
guid: testGUID, guid: testGUID,
gpo: false, gpo: false,
dnsFirewall: dnsfw.New(),
} }
// Create 125 domains which will result in 3 NRPT rules (50+50+25) // Create 125 domains which will result in 3 NRPT rules (50+50+25)
@@ -137,9 +134,8 @@ func TestNRPTDomainBatching(t *testing.T) {
}() }()
cfg := &registryConfigurator{ cfg := &registryConfigurator{
guid: testGUID, guid: testGUID,
gpo: false, gpo: false,
dnsFirewall: dnsfw.New(),
} }
testCases := []struct { testCases := []struct {

View File

@@ -89,21 +89,33 @@ func (p *Provider) ListConnectors(ctx context.Context) ([]*ConnectorConfig, erro
} }
// UpdateConnector updates an existing connector in Dex storage. // UpdateConnector updates an existing connector in Dex storage.
// It merges incoming updates with existing values to prevent data loss on partial updates. // It overlays user-mutable config fields (issuer, clientID, clientSecret,
// redirectURI) onto the stored connector config, and updates the connector name
// when cfg.Name is set. Empty fields on cfg leave stored values unchanged, so
// partial updates preserve create-time defaults such as scopes, claimMapping,
// and userIDKey.
func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error { func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error {
if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) { if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) {
oldCfg, err := p.parseStorageConnector(old) if cfg.Type != "" && cfg.Type != inferIdentityProviderType(old.Type, cfg.ID, nil) {
if err != nil { return storage.Connector{}, errors.New("connector type change not allowed")
return storage.Connector{}, fmt.Errorf("failed to parse existing connector: %w", err)
} }
mergeConnectorConfig(cfg, oldCfg) configData, err := overlayConnectorConfig(old.Config, cfg)
storageConn, err := p.buildStorageConnector(cfg)
if err != nil { if err != nil {
return storage.Connector{}, fmt.Errorf("failed to build connector: %w", err) return storage.Connector{}, fmt.Errorf("failed to overlay connector config: %w", err)
} }
return storageConn, nil
name := cfg.Name
if name == "" {
name = old.Name
}
return storage.Connector{
ID: cfg.ID,
Type: old.Type,
Name: name,
Config: configData,
}, nil
}); err != nil { }); err != nil {
return fmt.Errorf("failed to update connector: %w", err) return fmt.Errorf("failed to update connector: %w", err)
} }
@@ -112,23 +124,27 @@ func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) er
return nil return nil
} }
// mergeConnectorConfig preserves existing values for empty fields in the update. // overlayConnectorConfig writes only the user-mutable fields onto the existing
func mergeConnectorConfig(cfg, oldCfg *ConnectorConfig) { // stored config, preserving every other field (scopes, claimMapping, userIDKey,
if cfg.ClientSecret == "" { // insecure flags, etc.). Empty fields on cfg leave the existing value alone.
cfg.ClientSecret = oldCfg.ClientSecret func overlayConnectorConfig(oldConfig []byte, cfg *ConnectorConfig) ([]byte, error) {
var m map[string]any
if err := decodeConnectorConfig(oldConfig, &m); err != nil {
return nil, err
} }
if cfg.RedirectURI == "" { if cfg.Issuer != "" {
cfg.RedirectURI = oldCfg.RedirectURI m["issuer"] = cfg.Issuer
} }
if cfg.Issuer == "" && cfg.Type == oldCfg.Type { if cfg.ClientID != "" {
cfg.Issuer = oldCfg.Issuer m["clientID"] = cfg.ClientID
} }
if cfg.ClientID == "" { if cfg.ClientSecret != "" {
cfg.ClientID = oldCfg.ClientID m["clientSecret"] = cfg.ClientSecret
} }
if cfg.Name == "" { if cfg.RedirectURI != "" {
cfg.Name = oldCfg.Name m["redirectURI"] = cfg.RedirectURI
} }
return encodeConnectorConfig(m)
} }
// DeleteConnector removes a connector from Dex storage. // DeleteConnector removes a connector from Dex storage.
@@ -216,6 +232,10 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
oidcConfig["getUserInfo"] = true oidcConfig["getUserInfo"] = true
case "entra": case "entra":
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"} oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
// Use the Entra Object ID (oid) instead of the default OIDC sub claim.
// Entra issues sub as a per-app pairwise identifier that does not match
// the stable Object ID.
oidcConfig["userIDKey"] = "oid"
case "okta": case "okta":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"} oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
case "pocketid": case "pocketid":

205
idp/dex/connector_test.go Normal file
View File

@@ -0,0 +1,205 @@
package dex
import (
"context"
"encoding/json"
"log/slog"
"os"
"path/filepath"
"testing"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/sql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newTestProvider(t *testing.T) (*Provider, func()) {
t.Helper()
tmpDir, err := os.MkdirTemp("", "dex-connector-test-*")
require.NoError(t, err)
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
s, err := (&sql.SQLite3{File: filepath.Join(tmpDir, "dex.db")}).Open(logger)
require.NoError(t, err)
return &Provider{storage: s, logger: logger}, func() {
_ = s.Close()
_ = os.RemoveAll(tmpDir)
}
}
func TestBuildOIDCConnectorConfig_EntraSetsUserIDKey(t *testing.T) {
cfg := &ConnectorConfig{
ID: "entra-test",
Name: "Entra",
Type: "entra",
Issuer: "https://login.microsoftonline.com/tid/v2.0",
ClientID: "client-id",
ClientSecret: "client-secret",
}
data, err := buildOIDCConnectorConfig(cfg, "https://example.com/oauth2/callback")
require.NoError(t, err)
var m map[string]any
require.NoError(t, json.Unmarshal(data, &m))
assert.Equal(t, "oid", m["userIDKey"], "entra connectors must default userIDKey to oid")
assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"])
}
func TestBuildOIDCConnectorConfig_NonEntraDoesNotSetUserIDKey(t *testing.T) {
// ensures the Entra userIDKey override does not leak into other OIDC providers,
// which already use a stable sub claim.
for _, typ := range []string{"oidc", "zitadel", "okta", "pocketid", "authentik", "keycloak", "adfs"} {
t.Run(typ, func(t *testing.T) {
data, err := buildOIDCConnectorConfig(&ConnectorConfig{Type: typ}, "https://example.com/oauth2/callback")
require.NoError(t, err)
var m map[string]any
require.NoError(t, json.Unmarshal(data, &m))
_, ok := m["userIDKey"]
assert.False(t, ok, "%s connectors must not have userIDKey set", typ)
})
}
}
func TestUpdateConnector_PreservesCreateTimeDefaults(t *testing.T) {
ctx := context.Background()
p, cleanup := newTestProvider(t)
defer cleanup()
created, err := p.CreateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Name: "Entra",
Type: "entra",
Issuer: "https://login.microsoftonline.com/tid/v2.0",
ClientID: "client-id",
ClientSecret: "old-secret",
RedirectURI: "https://example.com/oauth2/callback",
})
require.NoError(t, err)
require.Equal(t, "entra-test", created.ID)
// Rotate only the client secret.
err = p.UpdateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Type: "entra",
ClientSecret: "new-secret",
})
require.NoError(t, err)
conn, err := p.storage.GetConnector(ctx, "entra-test")
require.NoError(t, err)
var m map[string]any
require.NoError(t, json.Unmarshal(conn.Config, &m))
assert.Equal(t, "new-secret", m["clientSecret"], "clientSecret should be rotated")
assert.Equal(t, "client-id", m["clientID"], "clientID must survive (overlay should leave it alone)")
assert.Equal(t, "https://login.microsoftonline.com/tid/v2.0", m["issuer"])
assert.Equal(t, "oid", m["userIDKey"], "userIDKey must survive update")
assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"], "claimMapping must survive update")
}
func TestUpdateConnector_DoesNotAddUserIDKeyToExistingConnector(t *testing.T) {
ctx := context.Background()
p, cleanup := newTestProvider(t)
defer cleanup()
// Seed a connector directly into storage without userIDKey
preFixConfig, err := json.Marshal(map[string]any{
"issuer": "https://login.microsoftonline.com/tid/v2.0",
"clientID": "client-id",
"clientSecret": "old-secret",
"redirectURI": "https://example.com/oauth2/callback",
"scopes": []string{"openid", "profile", "email"},
"claimMapping": map[string]string{"email": "preferred_username"},
})
require.NoError(t, err)
require.NoError(t, p.storage.CreateConnector(ctx, storage.Connector{
ID: "entra-prefix",
Type: "oidc",
Name: "Entra",
Config: preFixConfig,
}))
// Rotate client secret via UpdateConnector.
err = p.UpdateConnector(ctx, &ConnectorConfig{
ID: "entra-prefix",
Type: "entra",
ClientSecret: "new-secret",
})
require.NoError(t, err)
conn, err := p.storage.GetConnector(ctx, "entra-prefix")
require.NoError(t, err)
var m map[string]any
require.NoError(t, json.Unmarshal(conn.Config, &m))
assert.Equal(t, "new-secret", m["clientSecret"])
_, has := m["userIDKey"]
assert.False(t, has, "userIDKey must not be auto-added to a connector that did not have it before")
}
func TestUpdateConnector_RejectsTypeChange(t *testing.T) {
ctx := context.Background()
p, cleanup := newTestProvider(t)
defer cleanup()
_, err := p.CreateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Name: "Entra",
Type: "entra",
Issuer: "https://login.microsoftonline.com/tid/v2.0",
ClientID: "client-id",
ClientSecret: "secret",
RedirectURI: "https://example.com/oauth2/callback",
})
require.NoError(t, err)
// Attempt to switch the connector to okta.
err = p.UpdateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Type: "okta",
})
require.Error(t, err)
assert.Contains(t, err.Error(), "connector type change not allowed")
// stored connector type/config unchanged after the rejected update.
conn, err := p.storage.GetConnector(ctx, "entra-test")
require.NoError(t, err)
assert.Equal(t, "oidc", conn.Type)
var m map[string]any
require.NoError(t, json.Unmarshal(conn.Config, &m))
assert.Equal(t, "oid", m["userIDKey"])
}
func TestUpdateConnector_AllowsSameTypeUpdate(t *testing.T) {
ctx := context.Background()
p, cleanup := newTestProvider(t)
defer cleanup()
_, err := p.CreateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Name: "Entra",
Type: "entra",
Issuer: "https://login.microsoftonline.com/old/v2.0",
ClientID: "client-id",
ClientSecret: "secret",
RedirectURI: "https://example.com/oauth2/callback",
})
require.NoError(t, err)
err = p.UpdateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Type: "entra",
Issuer: "https://login.microsoftonline.com/new/v2.0",
})
require.NoError(t, err)
conn, err := p.storage.GetConnector(ctx, "entra-test")
require.NoError(t, err)
var m map[string]any
require.NoError(t, json.Unmarshal(conn.Config, &m))
assert.Equal(t, "https://login.microsoftonline.com/new/v2.0", m["issuer"])
}

View File

@@ -11,6 +11,8 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"os"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -82,11 +84,40 @@ type ProxyServiceServer struct {
// Store for PKCE verifiers // Store for PKCE verifiers
pkceVerifierStore *PKCEVerifierStore pkceVerifierStore *PKCEVerifierStore
// tokenTTL is the lifetime of one-time tokens generated for proxy
// authentication. Defaults to defaultProxyTokenTTL when zero.
tokenTTL time.Duration
// snapshotBatchSize is the number of mappings per gRPC message during
// initial snapshot delivery. Configurable via NB_PROXY_SNAPSHOT_BATCH_SIZE.
snapshotBatchSize int
cancel context.CancelFunc cancel context.CancelFunc
} }
const pkceVerifierTTL = 10 * time.Minute const pkceVerifierTTL = 10 * time.Minute
const defaultProxyTokenTTL = 5 * time.Minute
const defaultSnapshotBatchSize = 500
func snapshotBatchSizeFromEnv() int {
if v := os.Getenv("NB_PROXY_SNAPSHOT_BATCH_SIZE"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
return n
}
}
return defaultSnapshotBatchSize
}
// proxyTokenTTL returns the configured token TTL or the default when unset.
func (s *ProxyServiceServer) proxyTokenTTL() time.Duration {
if s.tokenTTL > 0 {
return s.tokenTTL
}
return defaultProxyTokenTTL
}
// proxyConnection represents a connected proxy // proxyConnection represents a connected proxy
type proxyConnection struct { type proxyConnection struct {
proxyID string proxyID string
@@ -110,6 +141,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
peersManager: peersManager, peersManager: peersManager,
usersManager: usersManager, usersManager: usersManager,
proxyManager: proxyMgr, proxyManager: proxyMgr,
snapshotBatchSize: snapshotBatchSizeFromEnv(),
cancel: cancel, cancel: cancel,
} }
go s.cleanupStaleProxies(ctx) go s.cleanupStaleProxies(ctx)
@@ -192,11 +224,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
cancel: cancel, cancel: cancel,
} }
s.connectedProxies.Store(proxyID, conn)
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
// Register proxy in database with capabilities // Register proxy in database with capabilities
var caps *proxy.Capabilities var caps *proxy.Capabilities
if c := req.GetCapabilities(); c != nil { if c := req.GetCapabilities(); c != nil {
@@ -209,13 +236,31 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps) proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err) log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
s.connectedProxies.CompareAndDelete(proxyID, conn) cancel()
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
}
return status.Errorf(codes.Internal, "register proxy in database: %v", err) return status.Errorf(codes.Internal, "register proxy in database: %v", err)
} }
s.connectedProxies.Store(proxyID, conn)
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
if err := s.sendSnapshot(ctx, conn); err != nil {
if s.connectedProxies.CompareAndDelete(proxyID, conn) {
if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil {
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr)
}
}
cancel()
if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil {
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr)
}
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
}
errChan := make(chan error, 2)
go s.sender(conn, errChan)
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"proxy_id": proxyID, "proxy_id": proxyID,
"session_id": sessionID, "session_id": sessionID,
@@ -241,13 +286,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID) log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
}() }()
if err := s.sendSnapshot(ctx, conn); err != nil {
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
}
errChan := make(chan error, 2)
go s.sender(conn, errChan)
go s.heartbeat(connCtx, proxyRecord) go s.heartbeat(connCtx, proxyRecord)
select { select {
@@ -290,22 +328,27 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
return err return err
} }
// Send mappings in batches to reduce per-message gRPC overhead while
// staying well within the default 4 MB message size limit.
for i := 0; i < len(mappings); i += s.snapshotBatchSize {
end := i + s.snapshotBatchSize
if end > len(mappings) {
end = len(mappings)
}
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: mappings[i:end],
InitialSyncComplete: end == len(mappings),
}); err != nil {
return fmt.Errorf("send snapshot batch: %w", err)
}
}
if len(mappings) == 0 { if len(mappings) == 0 {
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
InitialSyncComplete: true, InitialSyncComplete: true,
}); err != nil { }); err != nil {
return fmt.Errorf("send snapshot completion: %w", err) return fmt.Errorf("send snapshot completion: %w", err)
} }
return nil
}
for i, m := range mappings {
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{m},
InitialSyncComplete: i == len(mappings)-1,
}); err != nil {
return fmt.Errorf("send proxy mapping: %w", err)
}
} }
return nil return nil
@@ -323,13 +366,9 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
continue continue
} }
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute) token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL())
if err != nil { if err != nil {
log.WithFields(log.Fields{ return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err)
"service": service.Name,
"account": service.AccountID,
}).WithError(err).Error("failed to generate auth token for snapshot")
continue
} }
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig()) m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
@@ -409,13 +448,16 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
conn := value.(*proxyConnection) conn := value.(*proxyConnection)
resp := s.perProxyMessage(update, conn.proxyID) resp := s.perProxyMessage(update, conn.proxyID)
if resp == nil { if resp == nil {
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
conn.cancel()
return true return true
} }
select { select {
case conn.sendChan <- resp: case conn.sendChan <- resp:
log.Debugf("Sent service update to proxy server %s", conn.proxyID) log.Debugf("Sent service update to proxy server %s", conn.proxyID)
default: default:
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID) log.Warnf("Send channel full for proxy %s, disconnecting to force resync", conn.proxyID)
conn.cancel()
} }
return true return true
}) })
@@ -495,13 +537,16 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
} }
msg := s.perProxyMessage(updateResponse, proxyID) msg := s.perProxyMessage(updateResponse, proxyID)
if msg == nil { if msg == nil {
log.WithContext(ctx).Warnf("Token generation failed for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
conn.cancel()
continue continue
} }
select { select {
case conn.sendChan <- msg: case conn.sendChan <- msg:
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
default: default:
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) log.WithContext(ctx).Warnf("Send channel full for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
conn.cancel()
} }
} }
} }
@@ -527,7 +572,8 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo
// perProxyMessage returns a copy of update with a fresh one-time token for // perProxyMessage returns a copy of update with a fresh one-time token for
// create/update operations. For delete operations the original mapping is // create/update operations. For delete operations the original mapping is
// used unchanged because proxies do not need to authenticate for removal. // used unchanged because proxies do not need to authenticate for removal.
// Returns nil if token generation fails (the proxy should be skipped). // Returns nil if token generation fails; the caller must disconnect the
// proxy so it can resync via a fresh snapshot on reconnect.
func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse { func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse {
resp := make([]*proto.ProxyMapping, 0, len(update.Mapping)) resp := make([]*proto.ProxyMapping, 0, len(update.Mapping))
for _, mapping := range update.Mapping { for _, mapping := range update.Mapping {
@@ -536,7 +582,7 @@ func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateRespo
continue continue
} }
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute) token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, s.proxyTokenTTL())
if err != nil { if err != nil {
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err) log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
return nil return nil

View File

@@ -0,0 +1,174 @@
package grpc
import (
"context"
"fmt"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/shared/management/proto"
)
// recordingStream captures all messages sent via Send so tests can inspect
// batching behaviour without a real gRPC transport.
type recordingStream struct {
grpc.ServerStream
messages []*proto.GetMappingUpdateResponse
}
func (s *recordingStream) Send(m *proto.GetMappingUpdateResponse) error {
s.messages = append(s.messages, m)
return nil
}
func (s *recordingStream) Context() context.Context { return context.Background() }
func (s *recordingStream) SetHeader(metadata.MD) error { return nil }
func (s *recordingStream) SendHeader(metadata.MD) error { return nil }
func (s *recordingStream) SetTrailer(metadata.MD) {}
func (s *recordingStream) SendMsg(any) error { return nil }
func (s *recordingStream) RecvMsg(any) error { return nil }
// makeServices creates n enabled services assigned to the given cluster.
func makeServices(n int, cluster string) []*rpservice.Service {
services := make([]*rpservice.Service, n)
for i := range n {
services[i] = &rpservice.Service{
ID: fmt.Sprintf("svc-%d", i),
AccountID: "acct-1",
Name: fmt.Sprintf("svc-%d", i),
Domain: fmt.Sprintf("svc-%d.example.com", i),
ProxyCluster: cluster,
Enabled: true,
Targets: []*rpservice.Target{
{TargetType: rpservice.TargetTypeHost, TargetId: "host-1"},
},
}
}
return services
}
func newSnapshotTestServer(t *testing.T, batchSize int) *ProxyServiceServer {
t.Helper()
s := &ProxyServiceServer{
tokenStore: NewOneTimeTokenStore(context.Background(), testCacheStore(t)),
snapshotBatchSize: batchSize,
}
s.SetProxyController(newTestProxyController())
return s
}
func TestSendSnapshot_BatchesMappings(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 3
const totalServices = 7 // 3 + 3 + 1
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
stream := &recordingStream{}
conn := &proxyConnection{
proxyID: "proxy-a",
address: cluster,
stream: stream,
}
err := s.sendSnapshot(context.Background(), conn)
require.NoError(t, err)
// Expect ceil(7/3) = 3 messages
require.Len(t, stream.messages, 3, "should send ceil(totalServices/batchSize) messages")
assert.Len(t, stream.messages[0].Mapping, 3)
assert.False(t, stream.messages[0].InitialSyncComplete, "first batch should not be sync-complete")
assert.Len(t, stream.messages[1].Mapping, 3)
assert.False(t, stream.messages[1].InitialSyncComplete, "middle batch should not be sync-complete")
assert.Len(t, stream.messages[2].Mapping, 1)
assert.True(t, stream.messages[2].InitialSyncComplete, "last batch must be sync-complete")
// Verify all service IDs are present exactly once
seen := make(map[string]bool)
for _, msg := range stream.messages {
for _, m := range msg.Mapping {
assert.False(t, seen[m.Id], "duplicate service ID %s", m.Id)
seen[m.Id] = true
}
}
assert.Len(t, seen, totalServices)
}
func TestSendSnapshot_ExactBatchMultiple(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 3
const totalServices = 6 // exactly 2 batches
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
stream := &recordingStream{}
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
require.NoError(t, s.sendSnapshot(context.Background(), conn))
require.Len(t, stream.messages, 2)
assert.Len(t, stream.messages[0].Mapping, 3)
assert.False(t, stream.messages[0].InitialSyncComplete)
assert.Len(t, stream.messages[1].Mapping, 3)
assert.True(t, stream.messages[1].InitialSyncComplete)
}
func TestSendSnapshot_SingleBatch(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 100
const totalServices = 5
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
stream := &recordingStream{}
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
require.NoError(t, s.sendSnapshot(context.Background(), conn))
require.Len(t, stream.messages, 1, "all mappings should fit in one batch")
assert.Len(t, stream.messages[0].Mapping, totalServices)
assert.True(t, stream.messages[0].InitialSyncComplete)
}
func TestSendSnapshot_EmptySnapshot(t *testing.T) {
const cluster = "cluster.example.com"
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil)
s := newSnapshotTestServer(t, 500)
s.serviceManager = mgr
stream := &recordingStream{}
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
require.NoError(t, s.sendSnapshot(context.Background(), conn))
require.Len(t, stream.messages, 1, "empty snapshot must still send sync-complete")
assert.Empty(t, stream.messages[0].Mapping)
assert.True(t, stream.messages[0].InitialSyncComplete)
}

View File

@@ -85,11 +85,14 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
// registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities. // registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities.
func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse { func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse {
ch := make(chan *proto.GetMappingUpdateResponse, 10) ch := make(chan *proto.GetMappingUpdateResponse, 10)
ctx, cancel := context.WithCancel(context.Background())
conn := &proxyConnection{ conn := &proxyConnection{
proxyID: proxyID, proxyID: proxyID,
address: clusterAddr, address: clusterAddr,
capabilities: caps, capabilities: caps,
sendChan: ch, sendChan: ch,
ctx: ctx,
cancel: cancel,
} }
s.connectedProxies.Store(proxyID, conn) s.connectedProxies.Store(proxyID, conn)

View File

@@ -1040,6 +1040,13 @@ func (am *DefaultAccountManager) lookupCache(ctx context.Context, accountUsers m
for attempt := 1; attempt <= maxAttempts; attempt++ { for attempt := 1; attempt <= maxAttempts; attempt++ {
if am.isCacheFresh(ctx, accountUsers, data) { if am.isCacheFresh(ctx, accountUsers, data) {
// Catch the silent vacuous-fresh case: empty accountUsers map + empty cache data
// → isCacheFresh returns true without iterating, skipping refreshCache,
// returning empty data. This matters when the account is mostly integration
// (SCIM) users and the InternalCache has been flushed.
if len(accountUsers) == 0 && len(data) == 0 {
log.WithContext(ctx).Warnf("lookupCache VACUOUS FRESH: accountUsers map is empty AND cache is empty for account %s — returning empty data without triggering loadAccount", accountID)
}
return data, nil return data, nil
} }

View File

@@ -7,11 +7,8 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -45,11 +42,6 @@ func newGeolocationsHandlerHandler(accountManager account.Manager, geolocationMa
// getAllCountries retrieves a list of all countries // getAllCountries retrieves a list of all countries
func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) { func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(r.Context(), err, w)
return
}
if l.geolocationManager == nil { if l.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready // TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w) util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
@@ -71,11 +63,6 @@ func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Req
// getCitiesByCountry retrieves a list of cities based on the given country code // getCitiesByCountry retrieves a list of cities based on the given country code
func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) { func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r) vars := mux.Vars(r)
countryCode := vars["country"] countryCode := vars["country"]
if !countryCodeRegex.MatchString(countryCode) { if !countryCodeRegex.MatchString(countryCode) {
@@ -102,27 +89,6 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.
util.WriteJSONObject(r.Context(), w, cities) util.WriteJSONObject(r.Context(), w, cities)
} }
func (l *geolocationsHandler) authenticateUser(r *http.Request) error {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
return err
}
accountID, userID := userAuth.AccountId, userAuth.UserId
allowed, err := l.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
return nil
}
func toCountryResponse(country geolocation.Country) api.Country { func toCountryResponse(country geolocation.Country) api.Country {
return api.Country{ return api.Country{
CountryName: country.CountryName, CountryName: country.CountryName,

View File

@@ -144,8 +144,11 @@ func TestValidateInviteToken_ModifiedToken(t *testing.T) {
_, plainToken, err := GenerateInviteToken() _, plainToken, err := GenerateInviteToken()
require.NoError(t, err) require.NoError(t, err)
// Modify one character in the secret part replacement := "X"
modifiedToken := plainToken[:5] + "X" + plainToken[6:] if plainToken[5] == 'X' {
replacement = "Y"
}
modifiedToken := plainToken[:5] + replacement + plainToken[6:]
err = ValidateInviteToken(modifiedToken) err = ValidateInviteToken(modifiedToken)
require.Error(t, err) require.Error(t, err)
} }

View File

@@ -1061,15 +1061,28 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
queriedUsers = append(queriedUsers, usersFromIntegration...) queriedUsers = append(queriedUsers, usersFromIntegration...)
} }
idpManagerNil := isNil(am.idpManager)
idpManagerEmbedded := !idpManagerNil && IsEmbeddedIdp(am.idpManager)
userInfosMap := make(map[string]*types.UserInfo) userInfosMap := make(map[string]*types.UserInfo)
// in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo // in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo
if len(queriedUsers) == 0 { if len(queriedUsers) == 0 {
var earlyReturnEmpty int
var earlyReturnEmptySamples []string
for _, accountUser := range accountUsers { for _, accountUser := range accountUsers {
info, err := accountUser.ToUserInfo(nil) info, err := accountUser.ToUserInfo(nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !accountUser.IsServiceUser && (info.Email == "" || info.Name == "") {
earlyReturnEmpty++
if len(earlyReturnEmptySamples) < 50 {
earlyReturnEmptySamples = append(earlyReturnEmptySamples,
fmt.Sprintf("%s(issued=%s,db.email=%q,db.name=%q)",
accountUser.Id, accountUser.Issued, accountUser.Email, accountUser.Name))
}
}
// Try to decode Dex user ID to extract the IdP ID (connector ID) // Try to decode Dex user ID to extract the IdP ID (connector ID)
if _, connectorID, decodeErr := dex.DecodeDexUserID(accountUser.Id); decodeErr == nil && connectorID != "" { if _, connectorID, decodeErr := dex.DecodeDexUserID(accountUser.Id); decodeErr == nil && connectorID != "" {
info.IdPID = connectorID info.IdPID = connectorID
@@ -1077,17 +1090,61 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
userInfosMap[accountUser.Id] = info userInfosMap[accountUser.Id] = info
} }
log.WithContext(ctx).Warnf("BuildUserInfosForAccount EARLY RETURN: queriedUsers empty, returning %d users with DB-only data (idpManagerNil=%v, idpManagerEmbedded=%v). %d non-service users have empty email/name in DB. Samples: %v",
len(accountUsers), idpManagerNil, idpManagerEmbedded, earlyReturnEmpty, earlyReturnEmptySamples)
// Same canonical-truth final scan, also on the early-return path
var finalEmptyEmail, finalEmptyName int
var finalEmptySamples []string
for id, info := range userInfosMap {
if info.IsServiceUser {
continue
}
if info.Email == "" {
finalEmptyEmail++
}
if info.Name == "" {
finalEmptyName++
}
if (info.Email == "" || info.Name == "") && len(finalEmptySamples) < 200 {
finalEmptySamples = append(finalEmptySamples,
fmt.Sprintf("%s(email=%q,name=%q,issued=%s)", id, info.Email, info.Name, info.Issued))
}
}
if finalEmptyEmail > 0 || finalEmptyName > 0 {
log.WithContext(ctx).Warnf("BuildUserInfosForAccount FINAL (early-return path): returning %d UserInfo entries — %d with empty email, %d with empty name. Samples: %v",
len(userInfosMap), finalEmptyEmail, finalEmptyName, finalEmptySamples)
}
return userInfosMap, nil return userInfosMap, nil
} }
var cacheHitEmpty, fallbackMiss int
var cacheHitEmptySamples, fallbackMissSamples []string
for _, localUser := range accountUsers { for _, localUser := range accountUsers {
var info *types.UserInfo var info *types.UserInfo
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains { if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
if !localUser.IsServiceUser && (queriedUser.Email == "" || queriedUser.Name == "") {
cacheHitEmpty++
if len(cacheHitEmptySamples) < 50 {
cacheHitEmptySamples = append(cacheHitEmptySamples,
fmt.Sprintf("%s(cache.email=%q,cache.name=%q,db.email=%q,db.name=%q)",
localUser.Id, queriedUser.Email, queriedUser.Name, localUser.Email, localUser.Name))
}
}
info, err = localUser.ToUserInfo(queriedUser) info, err = localUser.ToUserInfo(queriedUser)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else {
if !localUser.IsServiceUser {
fallbackMiss++
if len(fallbackMissSamples) < 50 {
fallbackMissSamples = append(fallbackMissSamples,
fmt.Sprintf("%s(issued=%s,db.email=%q,db.name=%q)",
localUser.Id, localUser.Issued, localUser.Email, localUser.Name))
}
}
name := "" name := ""
if localUser.IsServiceUser { if localUser.IsServiceUser {
name = localUser.ServiceUserName name = localUser.ServiceUserName
@@ -1111,6 +1168,40 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
userInfosMap[info.ID] = info userInfosMap[info.ID] = info
} }
if cacheHitEmpty > 0 {
log.WithContext(ctx).Warnf("BuildUserInfosForAccount: %d users found in cache with empty email or name (cache pollution). Samples: %v",
cacheHitEmpty, cacheHitEmptySamples)
}
if fallbackMiss > 0 {
log.WithContext(ctx).Warnf("BuildUserInfosForAccount: %d non-service users missed both caches (will get empty Name in API response from fallback). Samples: %v",
fallbackMiss, fallbackMissSamples)
}
// Canonical-truth log: scan what we are actually about to return to the handler.
// This catches empties from any code path (early-return, cache-hit-empty, fallback-miss,
// or anything we haven't identified yet).
var finalEmptyEmail, finalEmptyName int
var finalEmptySamples []string
for id, info := range userInfosMap {
if info.IsServiceUser {
continue
}
if info.Email == "" {
finalEmptyEmail++
}
if info.Name == "" {
finalEmptyName++
}
if (info.Email == "" || info.Name == "") && len(finalEmptySamples) < 200 {
finalEmptySamples = append(finalEmptySamples,
fmt.Sprintf("%s(email=%q,name=%q,issued=%s)", id, info.Email, info.Name, info.Issued))
}
}
if finalEmptyEmail > 0 || finalEmptyName > 0 {
log.WithContext(ctx).Warnf("BuildUserInfosForAccount FINAL: returning %d UserInfo entries — %d with empty email, %d with empty name. Samples: %v",
len(userInfosMap), finalEmptyEmail, finalEmptyName, finalEmptySamples)
}
return userInfosMap, nil return userInfosMap, nil
} }

View File

@@ -364,14 +364,16 @@ func TestIntegration_ProxyConnection_HappyPath(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
// Receive all mappings from the snapshot - server sends each mapping individually
mappingsByID := make(map[string]*proto.ProxyMapping) mappingsByID := make(map[string]*proto.ProxyMapping)
for i := 0; i < 2; i++ { for {
msg, err := stream.Recv() msg, err := stream.Recv()
require.NoError(t, err) require.NoError(t, err)
for _, m := range msg.GetMapping() { for _, m := range msg.GetMapping() {
mappingsByID[m.GetId()] = m mappingsByID[m.GetId()] = m
} }
if msg.GetInitialSyncComplete() {
break
}
} }
// Should receive 2 mappings total // Should receive 2 mappings total
@@ -411,12 +413,14 @@ func TestIntegration_ProxyConnection_SendsClusterAddress(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
// Receive all mappings - server sends each mapping individually
mappings := make([]*proto.ProxyMapping, 0) mappings := make([]*proto.ProxyMapping, 0)
for i := 0; i < 2; i++ { for {
msg, err := stream.Recv() msg, err := stream.Recv()
require.NoError(t, err) require.NoError(t, err)
mappings = append(mappings, msg.GetMapping()...) mappings = append(mappings, msg.GetMapping()...)
if msg.GetInitialSyncComplete() {
break
}
} }
// Should receive the 2 mappings matching the cluster // Should receive the 2 mappings matching the cluster
@@ -440,13 +444,15 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
clusterAddress := "test.proxy.io" clusterAddress := "test.proxy.io"
proxyID := "test-proxy-reconnect" proxyID := "test-proxy-reconnect"
// Helper to receive all mappings from a stream receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping {
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient, count int) []*proto.ProxyMapping {
var mappings []*proto.ProxyMapping var mappings []*proto.ProxyMapping
for i := 0; i < count; i++ { for {
msg, err := stream.Recv() msg, err := stream.Recv()
require.NoError(t, err) require.NoError(t, err)
mappings = append(mappings, msg.GetMapping()...) mappings = append(mappings, msg.GetMapping()...)
if msg.GetInitialSyncComplete() {
break
}
} }
return mappings return mappings
} }
@@ -460,7 +466,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
}) })
require.NoError(t, err) require.NoError(t, err)
firstMappings := receiveMappings(stream1, 2) firstMappings := receiveMappings(stream1)
cancel1() cancel1()
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
@@ -476,7 +482,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
}) })
require.NoError(t, err) require.NoError(t, err)
secondMappings := receiveMappings(stream2, 2) secondMappings := receiveMappings(stream2)
// Should receive the same mappings // Should receive the same mappings
assert.Equal(t, len(firstMappings), len(secondMappings), assert.Equal(t, len(firstMappings), len(secondMappings),
@@ -542,12 +548,14 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
} }
} }
// Helper to receive and apply all mappings
receiveAndApply := func(stream proto.ProxyService_GetMappingUpdateClient) { receiveAndApply := func(stream proto.ProxyService_GetMappingUpdateClient) {
for i := 0; i < 2; i++ { for {
msg, err := stream.Recv() msg, err := stream.Recv()
require.NoError(t, err) require.NoError(t, err)
applyMappings(msg.GetMapping()) applyMappings(msg.GetMapping())
if msg.GetInitialSyncComplete() {
break
}
} }
} }
@@ -636,12 +644,14 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T)
}) })
require.NoError(t, err) require.NoError(t, err)
// Receive all mappings - server sends each mapping individually
count := 0 count := 0
for i := 0; i < 2; i++ { for {
msg, err := stream.Recv() msg, err := stream.Recv()
require.NoError(t, err) require.NoError(t, err)
count += len(msg.GetMapping()) count += len(msg.GetMapping())
if msg.GetInitialSyncComplete() {
break
}
} }
mu.Lock() mu.Lock()
@@ -681,9 +691,12 @@ func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T)
}) })
require.NoError(t, err) require.NoError(t, err)
for i := 0; i < 2; i++ { for {
_, err := stream1.Recv() msg, err := stream1.Recv()
require.NoError(t, err) require.NoError(t, err)
if msg.GetInitialSyncComplete() {
break
}
} }
require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID, require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
@@ -699,9 +712,12 @@ func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T)
}) })
require.NoError(t, err) require.NoError(t, err)
for i := 0; i < 2; i++ { for {
_, err := stream2.Recv() msg, err := stream2.Recv()
require.NoError(t, err) require.NoError(t, err)
if msg.GetInitialSyncComplete() {
break
}
} }
cancel1() cancel1()

View File

@@ -943,6 +943,8 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
operation := func() error { operation := func() error {
s.Logger.Debug("connecting to management mapping stream") s.Logger.Debug("connecting to management mapping stream")
initialSyncDone = false
if s.healthChecker != nil { if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(false) s.healthChecker.SetManagementConnected(false)
} }
@@ -1000,6 +1002,11 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
return ctx.Err() return ctx.Err()
} }
var snapshotIDs map[types.ServiceID]struct{}
if !*initialSyncDone {
snapshotIDs = make(map[types.ServiceID]struct{})
}
for { for {
// Check for context completion to gracefully shutdown. // Check for context completion to gracefully shutdown.
select { select {
@@ -1020,17 +1027,45 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
s.processMappings(ctx, msg.GetMapping()) s.processMappings(ctx, msg.GetMapping())
s.Logger.Debug("Processing mapping update completed") s.Logger.Debug("Processing mapping update completed")
if !*initialSyncDone && msg.GetInitialSyncComplete() { if !*initialSyncDone {
if s.healthChecker != nil { for _, m := range msg.GetMapping() {
s.healthChecker.SetInitialSyncComplete() snapshotIDs[types.ServiceID(m.GetId())] = struct{}{}
}
if msg.GetInitialSyncComplete() {
s.reconcileSnapshot(ctx, snapshotIDs)
snapshotIDs = nil
if s.healthChecker != nil {
s.healthChecker.SetInitialSyncComplete()
}
*initialSyncDone = true
s.Logger.Info("Initial mapping sync complete")
} }
*initialSyncDone = true
s.Logger.Info("Initial mapping sync complete")
} }
} }
} }
} }
// reconcileSnapshot removes local mappings that are absent from the snapshot.
// This ensures services deleted while the proxy was disconnected get cleaned up.
func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) {
s.portMu.RLock()
var stale []*proto.ProxyMapping
for svcID, mapping := range s.lastMappings {
if _, ok := snapshotIDs[svcID]; !ok {
stale = append(stale, mapping)
}
}
s.portMu.RUnlock()
for _, mapping := range stale {
s.Logger.WithFields(log.Fields{
"service_id": mapping.GetId(),
"domain": mapping.GetDomain(),
}).Info("Removing stale mapping absent from snapshot")
s.removeMapping(ctx, mapping)
}
}
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) { func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
for _, mapping := range mappings { for _, mapping := range mappings {
s.Logger.WithFields(log.Fields{ s.Logger.WithFields(log.Fields{

View File

@@ -0,0 +1,227 @@
package proxy
import (
"context"
"io"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/internal/health"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// collectStaleIDs mirrors the stale-detection logic in reconcileSnapshot
// so we can verify it without triggering removeMapping (which requires full
// server wiring). This keeps the test focused on the detection algorithm.
func collectStaleIDs(lastMappings map[types.ServiceID]*proto.ProxyMapping, snapshotIDs map[types.ServiceID]struct{}) []types.ServiceID {
var stale []types.ServiceID
for svcID := range lastMappings {
if _, ok := snapshotIDs[svcID]; !ok {
stale = append(stale, svcID)
}
}
return stale
}
// TestStaleDetection_PartialOverlap verifies that only services absent from
// the snapshot are flagged as stale.
func TestStaleDetection_PartialOverlap(t *testing.T) {
local := map[types.ServiceID]*proto.ProxyMapping{
"svc-1": {Id: "svc-1"},
"svc-2": {Id: "svc-2"},
"svc-stale-a": {Id: "svc-stale-a"},
"svc-stale-b": {Id: "svc-stale-b"},
}
snapshot := map[types.ServiceID]struct{}{
"svc-1": {},
"svc-2": {},
"svc-3": {}, // new service, not in local
}
stale := collectStaleIDs(local, snapshot)
assert.Len(t, stale, 2)
staleSet := make(map[types.ServiceID]struct{})
for _, id := range stale {
staleSet[id] = struct{}{}
}
assert.Contains(t, staleSet, types.ServiceID("svc-stale-a"))
assert.Contains(t, staleSet, types.ServiceID("svc-stale-b"))
}
// TestStaleDetection_AllStale verifies an empty snapshot flags everything.
func TestStaleDetection_AllStale(t *testing.T) {
local := map[types.ServiceID]*proto.ProxyMapping{
"svc-1": {Id: "svc-1"},
"svc-2": {Id: "svc-2"},
}
stale := collectStaleIDs(local, map[types.ServiceID]struct{}{})
assert.Len(t, stale, 2)
}
// TestStaleDetection_NoneStale verifies full overlap produces no stale entries.
func TestStaleDetection_NoneStale(t *testing.T) {
local := map[types.ServiceID]*proto.ProxyMapping{
"svc-1": {Id: "svc-1"},
"svc-2": {Id: "svc-2"},
}
snapshot := map[types.ServiceID]struct{}{
"svc-1": {},
"svc-2": {},
}
stale := collectStaleIDs(local, snapshot)
assert.Empty(t, stale)
}
// TestStaleDetection_EmptyLocal verifies no stale entries when local is empty.
func TestStaleDetection_EmptyLocal(t *testing.T) {
stale := collectStaleIDs(
map[types.ServiceID]*proto.ProxyMapping{},
map[types.ServiceID]struct{}{"svc-1": {}},
)
assert.Empty(t, stale)
}
// TestReconcileSnapshot_NoStale verifies reconciliation is a no-op when all
// local mappings are present in the snapshot (removeMapping is never called).
func TestReconcileSnapshot_NoStale(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1"}
s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2"}
snapshotIDs := map[types.ServiceID]struct{}{
"svc-1": {},
"svc-2": {},
}
// This should not panic — no stale entries means removeMapping is never called.
s.reconcileSnapshot(context.Background(), snapshotIDs)
assert.Len(t, s.lastMappings, 2, "no mappings should be removed when all are in snapshot")
}
// TestReconcileSnapshot_EmptyLocal verifies reconciliation is a no-op with
// no local mappings.
func TestReconcileSnapshot_EmptyLocal(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
s.reconcileSnapshot(context.Background(), map[types.ServiceID]struct{}{"svc-1": {}})
assert.Empty(t, s.lastMappings)
}
// --- handleMappingStream tests for batched snapshot ID accumulation ---
// TestHandleMappingStream_BatchedSnapshotSyncComplete verifies that sync is
// marked done only after the final InitialSyncComplete message, even when
// the snapshot arrives in multiple batches.
func TestHandleMappingStream_BatchedSnapshotSyncComplete(t *testing.T) {
checker := health.NewChecker(nil, nil)
s := &Server{
Logger: log.StandardLogger(),
healthChecker: checker,
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
stream := &mockMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{}, // batch 1: no sync-complete
{}, // batch 2: no sync-complete
{InitialSyncComplete: true}, // batch 3: sync done
},
}
syncDone := false
err := s.handleMappingStream(context.Background(), stream, &syncDone)
assert.NoError(t, err)
assert.True(t, syncDone, "sync should be marked done after final batch")
}
// TestHandleMappingStream_PostSyncDoesNotReconcile verifies that messages
// arriving after InitialSyncComplete do not trigger a second reconciliation.
func TestHandleMappingStream_PostSyncDoesNotReconcile(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
// Simulate state left over from a previous sync.
s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1", AccountId: "acct-1"}
s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2", AccountId: "acct-1"}
stream := &mockMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{}, // post-sync empty message — must not reconcile
},
}
syncDone := true // sync already completed in a previous stream
err := s.handleMappingStream(context.Background(), stream, &syncDone)
require.NoError(t, err)
assert.Len(t, s.lastMappings, 2,
"post-sync messages must not trigger reconciliation — all entries should survive")
}
// TestHandleMappingStream_ImmediateEOF_NoReconciliation verifies that if the
// stream closes before sync completes, no reconciliation occurs.
func TestHandleMappingStream_ImmediateEOF_NoReconciliation(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}
stream := &mockMappingStream{} // no messages → immediate EOF
syncDone := false
err := s.handleMappingStream(context.Background(), stream, &syncDone)
assert.NoError(t, err)
assert.False(t, syncDone, "sync should not be marked done on immediate EOF")
_, hasStale := s.lastMappings["svc-stale"]
assert.True(t, hasStale, "stale mapping should remain when sync never completed")
}
// mockErrRecvStream returns an error on the second Recv to verify
// handleMappingStream returns without completing sync.
type mockErrRecvStream struct {
mockMappingStream
calls int
}
func (m *mockErrRecvStream) Recv() (*proto.GetMappingUpdateResponse, error) {
m.calls++
if m.calls == 1 {
return &proto.GetMappingUpdateResponse{}, nil
}
return nil, io.ErrUnexpectedEOF
}
func TestHandleMappingStream_ErrorMidSync_NoReconciliation(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}
syncDone := false
err := s.handleMappingStream(context.Background(), &mockErrRecvStream{}, &syncDone)
assert.Error(t, err)
assert.False(t, syncDone)
_, hasStale := s.lastMappings["svc-stale"]
assert.True(t, hasStale, "stale mapping should remain when sync was interrupted by error")
}