mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-10 18:59:55 +00:00
Compare commits
7 Commits
windows-dn
...
debug-logs
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6240dcd96a | ||
|
|
f532976e05 | ||
|
|
71a400f90f | ||
|
|
bfeb9b19ec | ||
|
|
b19b7464ea | ||
|
|
cfb1b3fe31 | ||
|
|
3c28d29725 |
@@ -92,9 +92,6 @@ linters:
|
|||||||
- linters:
|
- linters:
|
||||||
- unused
|
- unused
|
||||||
path: client/firewall/iptables/rule\.go
|
path: client/firewall/iptables/rule\.go
|
||||||
- linters:
|
|
||||||
- unused
|
|
||||||
path: client/internal/dns/dnsfw/(types|syscall|zsyscall)_windows.*\.go
|
|
||||||
- linters:
|
- linters:
|
||||||
- gosec
|
- gosec
|
||||||
- mirror
|
- mirror
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/protobuf/encoding/protojson"
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/anonymize"
|
"github.com/netbirdio/netbird/client/anonymize"
|
||||||
@@ -583,6 +584,9 @@ func isSensitiveEnvVar(key string) bool {
|
|||||||
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
|
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
|
||||||
configContent.WriteString("NetBird Client Configuration:\n\n")
|
configContent.WriteString("NetBird Client Configuration:\n\n")
|
||||||
|
|
||||||
|
if key, err := wgtypes.ParseKey(g.internalConfig.PrivateKey); err == nil {
|
||||||
|
configContent.WriteString(fmt.Sprintf("PublicKey: %s\n", key.PublicKey().String()))
|
||||||
|
}
|
||||||
configContent.WriteString(fmt.Sprintf("WgIface: %s\n", g.internalConfig.WgIface))
|
configContent.WriteString(fmt.Sprintf("WgIface: %s\n", g.internalConfig.WgIface))
|
||||||
configContent.WriteString(fmt.Sprintf("WgPort: %d\n", g.internalConfig.WgPort))
|
configContent.WriteString(fmt.Sprintf("WgPort: %d\n", g.internalConfig.WgPort))
|
||||||
if g.internalConfig.NetworkMonitor != nil {
|
if g.internalConfig.NetworkMonitor != nil {
|
||||||
@@ -607,6 +611,12 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
|||||||
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
|
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
|
||||||
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
|
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
|
||||||
}
|
}
|
||||||
|
if g.internalConfig.DisableSSHAuth != nil {
|
||||||
|
configContent.WriteString(fmt.Sprintf("DisableSSHAuth: %v\n", *g.internalConfig.DisableSSHAuth))
|
||||||
|
}
|
||||||
|
if g.internalConfig.SSHJWTCacheTTL != nil {
|
||||||
|
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
|
||||||
|
}
|
||||||
|
|
||||||
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
||||||
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||||
@@ -633,6 +643,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
|||||||
}
|
}
|
||||||
|
|
||||||
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
|
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
|
||||||
|
configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addProf() (err error) {
|
func (g *BundleGenerator) addProf() (err error) {
|
||||||
|
|||||||
@@ -5,16 +5,21 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net"
|
"net"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/anonymize"
|
"github.com/netbirdio/netbird/client/anonymize"
|
||||||
"github.com/netbirdio/netbird/client/configs"
|
"github.com/netbirdio/netbird/client/configs"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -471,8 +476,8 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
|
|||||||
anonymize: false,
|
anonymize: false,
|
||||||
input: map[string]any{
|
input: map[string]any{
|
||||||
jsonKeyServiceEnv: map[string]any{
|
jsonKeyServiceEnv: map[string]any{
|
||||||
"HOME": "/root",
|
"HOME": "/root",
|
||||||
"PATH": "/usr/bin",
|
"PATH": "/usr/bin",
|
||||||
"NB_LOG_LEVEL": "debug",
|
"NB_LOG_LEVEL": "debug",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -489,9 +494,9 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
|
|||||||
anonymize: false,
|
anonymize: false,
|
||||||
input: map[string]any{
|
input: map[string]any{
|
||||||
jsonKeyServiceEnv: map[string]any{
|
jsonKeyServiceEnv: map[string]any{
|
||||||
"NB_SETUP_KEY": "abc123",
|
"NB_SETUP_KEY": "abc123",
|
||||||
"NB_API_TOKEN": "tok_xyz",
|
"NB_API_TOKEN": "tok_xyz",
|
||||||
"NB_LOG_LEVEL": "info",
|
"NB_LOG_LEVEL": "info",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
check: func(t *testing.T, params map[string]any) {
|
check: func(t *testing.T, params map[string]any) {
|
||||||
@@ -766,3 +771,127 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
|
|||||||
assert.Contains(t, anonNftables, "chain input {")
|
assert.Contains(t, anonNftables, "chain input {")
|
||||||
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
|
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestAddConfig_AllFieldsCovered uses reflection to ensure every field in
|
||||||
|
// profilemanager.Config is either rendered in the debug bundle or explicitly
|
||||||
|
// excluded. When a new field is added to Config, this test fails until the
|
||||||
|
// developer either dumps it in addConfig/addCommonConfigFields or adds it to
|
||||||
|
// the excluded set with a justification.
|
||||||
|
func TestAddConfig_AllFieldsCovered(t *testing.T) {
|
||||||
|
excluded := map[string]string{
|
||||||
|
"PrivateKey": "sensitive: WireGuard private key",
|
||||||
|
"PreSharedKey": "sensitive: WireGuard pre-shared key",
|
||||||
|
"SSHKey": "sensitive: SSH private key",
|
||||||
|
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
|
||||||
|
}
|
||||||
|
|
||||||
|
mURL, _ := url.Parse("https://api.example.com:443")
|
||||||
|
aURL, _ := url.Parse("https://admin.example.com:443")
|
||||||
|
bTrue := true
|
||||||
|
iVal := 42
|
||||||
|
cfg := &profilemanager.Config{
|
||||||
|
PrivateKey: "priv",
|
||||||
|
PreSharedKey: "psk",
|
||||||
|
ManagementURL: mURL,
|
||||||
|
AdminURL: aURL,
|
||||||
|
WgIface: "wt0",
|
||||||
|
WgPort: 51820,
|
||||||
|
NetworkMonitor: &bTrue,
|
||||||
|
IFaceBlackList: []string{"eth0"},
|
||||||
|
DisableIPv6Discovery: true,
|
||||||
|
RosenpassEnabled: true,
|
||||||
|
RosenpassPermissive: true,
|
||||||
|
ServerSSHAllowed: &bTrue,
|
||||||
|
EnableSSHRoot: &bTrue,
|
||||||
|
EnableSSHSFTP: &bTrue,
|
||||||
|
EnableSSHLocalPortForwarding: &bTrue,
|
||||||
|
EnableSSHRemotePortForwarding: &bTrue,
|
||||||
|
DisableSSHAuth: &bTrue,
|
||||||
|
SSHJWTCacheTTL: &iVal,
|
||||||
|
DisableClientRoutes: true,
|
||||||
|
DisableServerRoutes: true,
|
||||||
|
DisableDNS: true,
|
||||||
|
DisableFirewall: true,
|
||||||
|
BlockLANAccess: true,
|
||||||
|
BlockInbound: true,
|
||||||
|
DisableNotifications: &bTrue,
|
||||||
|
DNSLabels: domain.List{},
|
||||||
|
SSHKey: "sshkey",
|
||||||
|
NATExternalIPs: []string{"1.2.3.4"},
|
||||||
|
CustomDNSAddress: "1.1.1.1:53",
|
||||||
|
DisableAutoConnect: true,
|
||||||
|
DNSRouteInterval: 5 * time.Second,
|
||||||
|
ClientCertPath: "/tmp/cert",
|
||||||
|
ClientCertKeyPath: "/tmp/key",
|
||||||
|
LazyConnectionEnabled: true,
|
||||||
|
MTU: 1280,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, anonymize := range []bool{false, true} {
|
||||||
|
t.Run("anonymize="+map[bool]string{true: "true", false: "false"}[anonymize], func(t *testing.T) {
|
||||||
|
g := &BundleGenerator{
|
||||||
|
anonymizer: newAnonymizerForTest(),
|
||||||
|
internalConfig: cfg,
|
||||||
|
anonymize: anonymize,
|
||||||
|
}
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
g.addCommonConfigFields(&sb)
|
||||||
|
rendered := sb.String() + renderAddConfigSpecific(g)
|
||||||
|
|
||||||
|
val := reflect.ValueOf(cfg).Elem()
|
||||||
|
typ := val.Type()
|
||||||
|
var missing []string
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
name := typ.Field(i).Name
|
||||||
|
if _, ok := excluded[name]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !strings.Contains(rendered, name+":") {
|
||||||
|
missing = append(missing, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(missing) > 0 {
|
||||||
|
t.Fatalf("Config field(s) not present in debug bundle output: %v\n"+
|
||||||
|
"Either render the field in addCommonConfigFields/addConfig, "+
|
||||||
|
"or add it to the excluded map with a justification.", missing)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderAddConfigSpecific renders the fields handled by the anonymize/non-anonymize
|
||||||
|
// branches in addConfig (ManagementURL, AdminURL, NATExternalIPs, CustomDNSAddress).
|
||||||
|
// addCommonConfigFields covers the rest. Keeping this in the test mirrors the
|
||||||
|
// production shape without needing to write an actual zip.
|
||||||
|
func renderAddConfigSpecific(g *BundleGenerator) string {
|
||||||
|
var sb strings.Builder
|
||||||
|
if g.anonymize {
|
||||||
|
if g.internalConfig.ManagementURL != nil {
|
||||||
|
sb.WriteString("ManagementURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.ManagementURL.String()) + "\n")
|
||||||
|
}
|
||||||
|
if g.internalConfig.AdminURL != nil {
|
||||||
|
sb.WriteString("AdminURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.AdminURL.String()) + "\n")
|
||||||
|
}
|
||||||
|
sb.WriteString("NATExternalIPs: x\n")
|
||||||
|
if g.internalConfig.CustomDNSAddress != "" {
|
||||||
|
sb.WriteString("CustomDNSAddress: " + g.anonymizer.AnonymizeString(g.internalConfig.CustomDNSAddress) + "\n")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if g.internalConfig.ManagementURL != nil {
|
||||||
|
sb.WriteString("ManagementURL: " + g.internalConfig.ManagementURL.String() + "\n")
|
||||||
|
}
|
||||||
|
if g.internalConfig.AdminURL != nil {
|
||||||
|
sb.WriteString("AdminURL: " + g.internalConfig.AdminURL.String() + "\n")
|
||||||
|
}
|
||||||
|
sb.WriteString("NATExternalIPs: x\n")
|
||||||
|
if g.internalConfig.CustomDNSAddress != "" {
|
||||||
|
sb.WriteString("CustomDNSAddress: " + g.internalConfig.CustomDNSAddress + "\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAnonymizerForTest() *anonymize.Anonymizer {
|
||||||
|
return anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,63 +0,0 @@
|
|||||||
package dnsfw
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// EnvDisable disables the DNS firewall entirely when set to a truthy value.
|
|
||||||
EnvDisable = "NB_DISABLE_DNS_FIREWALL"
|
|
||||||
// EnvPorts overrides the comma-separated list of remote ports to block.
|
|
||||||
// Empty disables the firewall.
|
|
||||||
EnvPorts = "NB_DNS_FIREWALL_PORTS"
|
|
||||||
// EnvStrict enables strict mode: permit DNS only to the virtual DNS IP
|
|
||||||
// and the netbird daemon. Default mode also permits anything on the
|
|
||||||
// netbird tunnel interface, which is safer if NRPT is silently ignored
|
|
||||||
// by Windows but lets apps reach custom DNS servers via the tunnel.
|
|
||||||
EnvStrict = "NB_DNS_FIREWALL_STRICT"
|
|
||||||
)
|
|
||||||
|
|
||||||
// defaultBlockedPorts are the well-known DNS ports we block for non-netbird
|
|
||||||
// processes: 53 (plain DNS) and 853 (DNS-over-TLS).
|
|
||||||
var defaultBlockedPorts = []uint16{53, 853}
|
|
||||||
|
|
||||||
// blockedPorts returns the effective port list, honoring env overrides.
|
|
||||||
// A nil return means the firewall should not be installed.
|
|
||||||
func blockedPorts() []uint16 {
|
|
||||||
if disabled, _ := strconv.ParseBool(os.Getenv(EnvDisable)); disabled {
|
|
||||||
log.Infof("dns firewall disabled via %s", EnvDisable)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
override, ok := os.LookupEnv(EnvPorts)
|
|
||||||
if !ok {
|
|
||||||
return defaultBlockedPorts
|
|
||||||
}
|
|
||||||
|
|
||||||
var ports []uint16
|
|
||||||
for _, raw := range strings.Split(override, ",") {
|
|
||||||
raw = strings.TrimSpace(raw)
|
|
||||||
if raw == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
port, err := strconv.ParseUint(raw, 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("dns firewall: ignoring invalid port %q in %s: %v", raw, EnvPorts, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if port == 0 {
|
|
||||||
log.Warnf("dns firewall: ignoring port 0 in %s", EnvPorts)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ports = append(ports, uint16(port))
|
|
||||||
}
|
|
||||||
if len(ports) == 0 {
|
|
||||||
log.Infof("dns firewall disabled: %s yielded no valid ports", EnvPorts)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return ports
|
|
||||||
}
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
package dnsfw
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestBlockedPorts(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
disable string
|
|
||||||
ports string
|
|
||||||
setPorts bool
|
|
||||||
want []uint16
|
|
||||||
}{
|
|
||||||
{name: "default", want: defaultBlockedPorts},
|
|
||||||
{name: "disabled", disable: "true", want: nil},
|
|
||||||
{name: "disabled false keeps default", disable: "false", want: defaultBlockedPorts},
|
|
||||||
{name: "override single port", ports: "53", setPorts: true, want: []uint16{53}},
|
|
||||||
{name: "override multi", ports: "53, 853 ,5353", setPorts: true, want: []uint16{53, 853, 5353}},
|
|
||||||
{name: "override empty disables", ports: "", setPorts: true, want: nil},
|
|
||||||
{name: "override invalid skipped", ports: "53,not-a-port,853", setPorts: true, want: []uint16{53, 853}},
|
|
||||||
{name: "override zero skipped", ports: "53,0,853", setPorts: true, want: []uint16{53, 853}},
|
|
||||||
{name: "override only invalid disables", ports: "abc", setPorts: true, want: nil},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
t.Setenv(EnvDisable, tc.disable)
|
|
||||||
if tc.setPorts {
|
|
||||||
t.Setenv(EnvPorts, tc.ports)
|
|
||||||
}
|
|
||||||
got := blockedPorts()
|
|
||||||
if !reflect.DeepEqual(got, tc.want) {
|
|
||||||
t.Fatalf("blockedPorts() = %v, want %v", got, tc.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
// Package dnsfw blocks DNS traffic from non-netbird processes when netbird is
|
|
||||||
// managing the host's DNS, so that resolvers running on apps or libraries
|
|
||||||
// outside netbird cannot bypass the configured DNS path.
|
|
||||||
//
|
|
||||||
// Implementation is Windows-only (uses WFP). On other platforms New returns
|
|
||||||
// a no-op manager.
|
|
||||||
package dnsfw
|
|
||||||
|
|
||||||
import "net/netip"
|
|
||||||
|
|
||||||
// Manager controls the per-tunnel DNS firewall. Both methods must be safe
|
|
||||||
// to call multiple times.
|
|
||||||
type Manager interface {
|
|
||||||
Enable(ifaceGUID string, virtualDNSIP netip.Addr) error
|
|
||||||
Disable() error
|
|
||||||
}
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
//go:build !windows
|
|
||||||
|
|
||||||
package dnsfw
|
|
||||||
|
|
||||||
import "net/netip"
|
|
||||||
|
|
||||||
type noopManager struct{}
|
|
||||||
|
|
||||||
func (noopManager) Enable(string, netip.Addr) error { return nil }
|
|
||||||
func (noopManager) Disable() error { return nil }
|
|
||||||
|
|
||||||
// New returns a no-op manager on non-Windows platforms.
|
|
||||||
func New() Manager {
|
|
||||||
return noopManager{}
|
|
||||||
}
|
|
||||||
@@ -1,144 +0,0 @@
|
|||||||
//go:build windows
|
|
||||||
|
|
||||||
package dnsfw
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
modIphlpapi = windows.NewLazyDLL("iphlpapi.dll")
|
|
||||||
procConvertInterfaceGuidToLuid = modIphlpapi.NewProc("ConvertInterfaceGuidToLuid")
|
|
||||||
)
|
|
||||||
|
|
||||||
type windowsManager struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
// session is the WFP engine handle. Zero when disabled.
|
|
||||||
session uintptr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enable installs the dns firewall. Strict mode propagates failures;
|
|
||||||
// non-strict mode logs and returns nil so partial protection is preserved.
|
|
||||||
func (m *windowsManager) Enable(ifaceGUID string, virtualDNSIP netip.Addr) error {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
ports := blockedPorts()
|
|
||||||
if len(ports) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.session != 0 {
|
|
||||||
if err := m.disableLocked(); err != nil {
|
|
||||||
return fmt.Errorf("reset existing dns firewall session: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
strict := strictMode()
|
|
||||||
|
|
||||||
luid, err := luidFromGUID(ifaceGUID)
|
|
||||||
if err != nil {
|
|
||||||
return m.failOrLog(strict, fmt.Errorf("resolve tun luid from guid %s: %w", ifaceGUID, err))
|
|
||||||
}
|
|
||||||
|
|
||||||
exe, err := os.Executable()
|
|
||||||
if err != nil {
|
|
||||||
return m.failOrLog(strict, fmt.Errorf("resolve daemon executable path: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg := installConfig{
|
|
||||||
tunLUID: luid,
|
|
||||||
daemonExe: exe,
|
|
||||||
blockedPorts: ports,
|
|
||||||
strict: strict,
|
|
||||||
virtualDNSIP: virtualDNSIP,
|
|
||||||
}
|
|
||||||
// session==0 signals a hard failure; non-zero with non-nil err is a partial install.
|
|
||||||
session, installErr := installFilters(cfg)
|
|
||||||
if session == 0 {
|
|
||||||
return m.failOrLog(strict, fmt.Errorf("install dns firewall filters: %w", installErr))
|
|
||||||
}
|
|
||||||
|
|
||||||
if installErr != nil && strict {
|
|
||||||
_ = closeSession(session)
|
|
||||||
return fmt.Errorf("strict dns firewall: partial install: %w", installErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.session = session
|
|
||||||
log.Infof("dns firewall installed: iface=%s daemon=%s ports=%v strict=%v virtual_dns=%s",
|
|
||||||
ifaceGUID, exe, ports, strict, virtualDNSIP)
|
|
||||||
if installErr != nil {
|
|
||||||
log.Warnf("dns firewall partially installed (some filters failed): %v", installErr)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *windowsManager) Disable() error {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
return m.disableLocked()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *windowsManager) disableLocked() error {
|
|
||||||
if m.session == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
session := m.session
|
|
||||||
m.session = 0
|
|
||||||
if err := closeSession(session); err != nil {
|
|
||||||
return fmt.Errorf("close wfp session: %w", err)
|
|
||||||
}
|
|
||||||
log.Info("dns firewall removed")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// failOrLog returns err unchanged in strict mode. In non-strict mode the
|
|
||||||
// error is logged and nil is returned.
|
|
||||||
func (m *windowsManager) failOrLog(strict bool, err error) error {
|
|
||||||
if strict {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
log.Errorf("dns firewall: %v", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// New returns a Windows DNS firewall manager backed by WFP.
|
|
||||||
func New() Manager {
|
|
||||||
return &windowsManager{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// strictMode reports whether strict mode is enabled via env.
|
|
||||||
func strictMode() bool {
|
|
||||||
v, _ := strconv.ParseBool(os.Getenv(EnvStrict))
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
// luidFromGUID converts a Windows interface GUID string to its LUID.
|
|
||||||
func luidFromGUID(ifaceGUID string) (luid uint64, err error) {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
err = fmt.Errorf("panic in luidFromGUID: %v", r)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
guid, err := windows.GUIDFromString(ifaceGUID)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("parse guid: %w", err)
|
|
||||||
}
|
|
||||||
rc, _, _ := procConvertInterfaceGuidToLuid.Call(
|
|
||||||
uintptr(unsafe.Pointer(&guid)),
|
|
||||||
uintptr(unsafe.Pointer(&luid)),
|
|
||||||
)
|
|
||||||
if rc != 0 {
|
|
||||||
return 0, fmt.Errorf("ConvertInterfaceGuidToLuid returned %d", rc)
|
|
||||||
}
|
|
||||||
return luid, nil
|
|
||||||
}
|
|
||||||
@@ -1,72 +0,0 @@
|
|||||||
//go:build windows
|
|
||||||
|
|
||||||
package dnsfw
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestStrictMode(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
val string
|
|
||||||
set bool
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{name: "unset", want: false},
|
|
||||||
{name: "true", val: "true", set: true, want: true},
|
|
||||||
{name: "1", val: "1", set: true, want: true},
|
|
||||||
{name: "false", val: "false", set: true, want: false},
|
|
||||||
{name: "invalid is false", val: "garbage", set: true, want: false},
|
|
||||||
}
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
t.Setenv(EnvStrict, tc.val)
|
|
||||||
if !tc.set {
|
|
||||||
os.Unsetenv(EnvStrict)
|
|
||||||
}
|
|
||||||
if got := strictMode(); got != tc.want {
|
|
||||||
t.Fatalf("strictMode() = %v, want %v", got, tc.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWindowsManagerDisableIdempotent(t *testing.T) {
|
|
||||||
m := &windowsManager{}
|
|
||||||
if err := m.Disable(); err != nil {
|
|
||||||
t.Fatalf("first Disable on fresh manager: %v", err)
|
|
||||||
}
|
|
||||||
if err := m.Disable(); err != nil {
|
|
||||||
t.Fatalf("second Disable on fresh manager: %v", err)
|
|
||||||
}
|
|
||||||
if m.session != 0 {
|
|
||||||
t.Fatalf("session should remain zero, got %d", m.session)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWindowsManagerEnableNoOpWhenDisabledByEnv(t *testing.T) {
|
|
||||||
t.Setenv(EnvDisable, "true")
|
|
||||||
|
|
||||||
m := &windowsManager{}
|
|
||||||
if err := m.Enable("00000000-0000-0000-0000-000000000000", netip.Addr{}); err != nil {
|
|
||||||
t.Fatalf("Enable should be a no-op when firewall disabled by env: %v", err)
|
|
||||||
}
|
|
||||||
if m.session != 0 {
|
|
||||||
t.Fatalf("session must remain zero when env disables firewall, got %d", m.session)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWindowsManagerEnableNoOpWhenPortsEmpty(t *testing.T) {
|
|
||||||
t.Setenv(EnvPorts, "")
|
|
||||||
|
|
||||||
m := &windowsManager{}
|
|
||||||
if err := m.Enable("00000000-0000-0000-0000-000000000000", netip.Addr{}); err != nil {
|
|
||||||
t.Fatalf("Enable should be a no-op when ports list is empty: %v", err)
|
|
||||||
}
|
|
||||||
if m.session != 0 {
|
|
||||||
t.Fatalf("session must remain zero when ports list is empty, got %d", m.session)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,53 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
*
|
|
||||||
* Adapted from wireguard-windows tunnel/firewall/helpers.go.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package dnsfw
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"runtime"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
func createWtFwpmDisplayData0(name, description string) (*wtFwpmDisplayData0, error) {
|
|
||||||
namePtr, err := windows.UTF16PtrFromString(name)
|
|
||||||
if err != nil {
|
|
||||||
return nil, wrapErr(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
descriptionPtr, err := windows.UTF16PtrFromString(description)
|
|
||||||
if err != nil {
|
|
||||||
return nil, wrapErr(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &wtFwpmDisplayData0{
|
|
||||||
name: namePtr,
|
|
||||||
description: descriptionPtr,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func filterWeight(weight uint8) wtFwpValue0 {
|
|
||||||
return wtFwpValue0{
|
|
||||||
_type: cFWP_UINT8,
|
|
||||||
value: uintptr(weight),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func wrapErr(err error) error {
|
|
||||||
var errno syscall.Errno
|
|
||||||
if !errors.As(err, &errno) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, file, line, ok := runtime.Caller(1)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("wfp error at unknown location: %w", err)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("wfp error at %s:%d: %w", file, line, err)
|
|
||||||
}
|
|
||||||
@@ -1,249 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
* Copyright (C) 2026 NetBird GmbH. All Rights Reserved.
|
|
||||||
*
|
|
||||||
* Filter installers adapted from wireguard-windows tunnel/firewall/rules.go.
|
|
||||||
* The block-DNS approach (port 53 + UDP/TCP) matches what wireguard-windows
|
|
||||||
* uses for its kill-switch DNS leak protection. We extend it with a
|
|
||||||
* configurable port set so we also cover :853 (DoT) and any future ports.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package dnsfw
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Filters install at outbound ALE_AUTH_CONNECT layers only; inbound replies
|
|
||||||
// follow the authorized outbound flow.
|
|
||||||
|
|
||||||
// permitTunInterface installs a permit filter for any traffic whose local
|
|
||||||
// interface is the netbird tunnel.
|
|
||||||
func permitTunInterface(session uintptr, base *baseObjects, weight uint8, ifLUID uint64) error {
|
|
||||||
cond := wtFwpmFilterCondition0{
|
|
||||||
fieldKey: cFWPM_CONDITION_IP_LOCAL_INTERFACE,
|
|
||||||
matchType: cFWP_MATCH_EQUAL,
|
|
||||||
conditionValue: wtFwpConditionValue0{
|
|
||||||
_type: cFWP_UINT64,
|
|
||||||
value: uintptr(unsafe.Pointer(&ifLUID)),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
filter := wtFwpmFilter0{
|
|
||||||
providerKey: &base.provider,
|
|
||||||
subLayerKey: base.filters,
|
|
||||||
weight: filterWeight(weight),
|
|
||||||
numFilterConditions: 1,
|
|
||||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&cond)),
|
|
||||||
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
|
|
||||||
}
|
|
||||||
|
|
||||||
return addOutboundFilters(session, &filter, "Permit netbird tunnel")
|
|
||||||
}
|
|
||||||
|
|
||||||
// permitDaemonByAppID installs a permit filter matching the netbird daemon
|
|
||||||
// executable by App-ID. App-ID alone is sufficient because netbird.exe is a
|
|
||||||
// dedicated binary.
|
|
||||||
func permitDaemonByAppID(session uintptr, base *baseObjects, daemonExe string, weight uint8) error {
|
|
||||||
appID, err := daemonAppID(daemonExe)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer fwpmFreeMemory0(unsafe.Pointer(&appID))
|
|
||||||
|
|
||||||
cond := wtFwpmFilterCondition0{
|
|
||||||
fieldKey: cFWPM_CONDITION_ALE_APP_ID,
|
|
||||||
matchType: cFWP_MATCH_EQUAL,
|
|
||||||
conditionValue: wtFwpConditionValue0{
|
|
||||||
_type: cFWP_BYTE_BLOB_TYPE,
|
|
||||||
value: uintptr(unsafe.Pointer(appID)),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
filter := wtFwpmFilter0{
|
|
||||||
providerKey: &base.provider,
|
|
||||||
subLayerKey: base.filters,
|
|
||||||
weight: filterWeight(weight),
|
|
||||||
numFilterConditions: 1,
|
|
||||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&cond)),
|
|
||||||
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
|
|
||||||
}
|
|
||||||
|
|
||||||
return addOutboundFilters(session, &filter, "Permit netbird daemon")
|
|
||||||
}
|
|
||||||
|
|
||||||
// permitVirtualDNSIP installs a permit filter for DNS-port traffic destined
|
|
||||||
// for the in-tunnel virtual DNS IP. Used in strict mode in lieu of
|
|
||||||
// permitTunInterface.
|
|
||||||
func permitVirtualDNSIP(session uintptr, base *baseObjects, ip netip.Addr, ports []uint16, weight uint8) error {
|
|
||||||
var merr *multierror.Error
|
|
||||||
for _, port := range ports {
|
|
||||||
if err := permitDNSToHost(session, base, ip, port, weight); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("permit %s:%d: %w", ip, port, err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func permitDNSToHost(session uintptr, base *baseObjects, ip netip.Addr, port uint16, weight uint8) error {
|
|
||||||
if !ip.IsValid() {
|
|
||||||
return fmt.Errorf("invalid address")
|
|
||||||
}
|
|
||||||
|
|
||||||
var addrCond wtFwpmFilterCondition0
|
|
||||||
var layer windows.GUID
|
|
||||||
// v6 backing must outlive fwpmFilterAdd0; keep it on this stack frame.
|
|
||||||
var v6 wtFwpByteArray16
|
|
||||||
|
|
||||||
if ip.Is4() {
|
|
||||||
v4 := ip.As4()
|
|
||||||
addrCond = wtFwpmFilterCondition0{
|
|
||||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS,
|
|
||||||
matchType: cFWP_MATCH_EQUAL,
|
|
||||||
conditionValue: wtFwpConditionValue0{
|
|
||||||
_type: cFWP_UINT32,
|
|
||||||
value: uintptr(binary.BigEndian.Uint32(v4[:])),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
layer = cFWPM_LAYER_ALE_AUTH_CONNECT_V4
|
|
||||||
} else {
|
|
||||||
v6 = wtFwpByteArray16{byteArray16: ip.As16()}
|
|
||||||
addrCond = wtFwpmFilterCondition0{
|
|
||||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS,
|
|
||||||
matchType: cFWP_MATCH_EQUAL,
|
|
||||||
conditionValue: wtFwpConditionValue0{
|
|
||||||
_type: cFWP_BYTE_ARRAY16_TYPE,
|
|
||||||
value: uintptr(unsafe.Pointer(&v6)),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
layer = cFWPM_LAYER_ALE_AUTH_CONNECT_V6
|
|
||||||
}
|
|
||||||
|
|
||||||
conditions := [2]wtFwpmFilterCondition0{
|
|
||||||
addrCond,
|
|
||||||
{
|
|
||||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_PORT,
|
|
||||||
matchType: cFWP_MATCH_EQUAL,
|
|
||||||
conditionValue: wtFwpConditionValue0{
|
|
||||||
_type: cFWP_UINT16,
|
|
||||||
value: uintptr(port),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
filter := wtFwpmFilter0{
|
|
||||||
providerKey: &base.provider,
|
|
||||||
subLayerKey: base.filters,
|
|
||||||
weight: filterWeight(weight),
|
|
||||||
numFilterConditions: uint32(len(conditions)),
|
|
||||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&conditions[0])),
|
|
||||||
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
|
|
||||||
}
|
|
||||||
|
|
||||||
display, err := createWtFwpmDisplayData0(fmt.Sprintf("Permit DNS to %s:%d", ip, port), "")
|
|
||||||
if err != nil {
|
|
||||||
return wrapErr(err)
|
|
||||||
}
|
|
||||||
filter.displayData = *display
|
|
||||||
filter.layerKey = layer
|
|
||||||
|
|
||||||
var filterID uint64
|
|
||||||
if err := fwpmFilterAdd0(session, &filter, 0, &filterID); err != nil {
|
|
||||||
return wrapErr(err)
|
|
||||||
}
|
|
||||||
_ = v6
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// blockDNSPorts installs a deny filter for outbound traffic to each of the
|
|
||||||
// given remote ports over UDP or TCP. Per-port and per-layer failures are
|
|
||||||
// accumulated; partial coverage is preferred over zero coverage.
|
|
||||||
func blockDNSPorts(session uintptr, base *baseObjects, ports []uint16, weight uint8) error {
|
|
||||||
var merr *multierror.Error
|
|
||||||
for _, port := range ports {
|
|
||||||
if err := blockDNSPort(session, base, port, weight); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("block port %d: %w", port, err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func blockDNSPort(session uintptr, base *baseObjects, port uint16, weight uint8) error {
|
|
||||||
conditions := [3]wtFwpmFilterCondition0{
|
|
||||||
{
|
|
||||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_PORT,
|
|
||||||
matchType: cFWP_MATCH_EQUAL,
|
|
||||||
conditionValue: wtFwpConditionValue0{
|
|
||||||
_type: cFWP_UINT16,
|
|
||||||
value: uintptr(port),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
fieldKey: cFWPM_CONDITION_IP_PROTOCOL,
|
|
||||||
matchType: cFWP_MATCH_EQUAL,
|
|
||||||
conditionValue: wtFwpConditionValue0{
|
|
||||||
_type: cFWP_UINT8,
|
|
||||||
value: uintptr(cIPPROTO_UDP),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
// Repeat the IP_PROTOCOL condition for logical OR with TCP.
|
|
||||||
{
|
|
||||||
fieldKey: cFWPM_CONDITION_IP_PROTOCOL,
|
|
||||||
matchType: cFWP_MATCH_EQUAL,
|
|
||||||
conditionValue: wtFwpConditionValue0{
|
|
||||||
_type: cFWP_UINT8,
|
|
||||||
value: uintptr(cIPPROTO_TCP),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
filter := wtFwpmFilter0{
|
|
||||||
providerKey: &base.provider,
|
|
||||||
subLayerKey: base.filters,
|
|
||||||
weight: filterWeight(weight),
|
|
||||||
numFilterConditions: uint32(len(conditions)),
|
|
||||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&conditions[0])),
|
|
||||||
action: wtFwpmAction0{_type: cFWP_ACTION_BLOCK},
|
|
||||||
}
|
|
||||||
|
|
||||||
return addOutboundFilters(session, &filter, fmt.Sprintf("Block DNS port %d", port))
|
|
||||||
}
|
|
||||||
|
|
||||||
// addOutboundFilters installs the same filter on the v4 and v6 outbound ALE
|
|
||||||
// connect layers. v4 and v6 are installed independently: failure on one
|
|
||||||
// layer does not abort the other, and the accumulated errors are returned.
|
|
||||||
// Partial coverage is preferred over zero coverage.
|
|
||||||
func addOutboundFilters(session uintptr, filter *wtFwpmFilter0, name string) error {
|
|
||||||
layers := [...]struct {
|
|
||||||
layer windows.GUID
|
|
||||||
label string
|
|
||||||
}{
|
|
||||||
{cFWPM_LAYER_ALE_AUTH_CONNECT_V4, name + " (IPv4)"},
|
|
||||||
{cFWPM_LAYER_ALE_AUTH_CONNECT_V6, name + " (IPv6)"},
|
|
||||||
}
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
for _, l := range layers {
|
|
||||||
display, err := createWtFwpmDisplayData0(l.label, "")
|
|
||||||
if err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("%s: %w", l.label, wrapErr(err)))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
filter.displayData = *display
|
|
||||||
filter.layerKey = l.layer
|
|
||||||
|
|
||||||
var filterID uint64
|
|
||||||
if err := fwpmFilterAdd0(session, filter, 0, &filterID); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("%s: %w", l.label, wrapErr(err)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
@@ -1,177 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
* Copyright (C) 2026 NetBird GmbH. All Rights Reserved.
|
|
||||||
*
|
|
||||||
* Session lifecycle and the high-level Install/Close entry points adapted
|
|
||||||
* from wireguard-windows tunnel/firewall.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package dnsfw
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// installConfig is the input to installFilters.
|
|
||||||
type installConfig struct {
|
|
||||||
tunLUID uint64
|
|
||||||
daemonExe string
|
|
||||||
blockedPorts []uint16
|
|
||||||
// strict, when true, narrows the carve-out from "anything on tun" to
|
|
||||||
// "DNS only to virtualDNSIP". virtualDNSIP must be valid in this case.
|
|
||||||
strict bool
|
|
||||||
virtualDNSIP netip.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
// baseObjects holds the GUIDs of the WFP provider and sublayer registered
|
|
||||||
// for our session. Both are randomly generated per session.
|
|
||||||
type baseObjects struct {
|
|
||||||
provider windows.GUID
|
|
||||||
filters windows.GUID
|
|
||||||
}
|
|
||||||
|
|
||||||
// installFilters opens a dynamic WFP session and installs the netbird DNS
|
|
||||||
// firewall filters. Returns a zero session on hard failure (session create,
|
|
||||||
// base objects); a non-zero session with a non-nil error is a partial install
|
|
||||||
// (some per-filter installs failed) and is safe to close.
|
|
||||||
func installFilters(cfg installConfig) (session uintptr, err error) {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
// Dynamic session: kernel will clean up on process exit even
|
|
||||||
// if we leave the handle dangling here.
|
|
||||||
err = fmt.Errorf("panic in installFilters: %v", r)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if len(cfg.blockedPorts) == 0 {
|
|
||||||
return 0, errors.New("dns firewall: no blocked ports configured")
|
|
||||||
}
|
|
||||||
if cfg.strict && !cfg.virtualDNSIP.IsValid() {
|
|
||||||
return 0, errors.New("dns firewall: strict mode requires a valid virtual DNS IP")
|
|
||||||
}
|
|
||||||
|
|
||||||
session, err = createSession()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
base, err := registerBaseObjects(session)
|
|
||||||
if err != nil {
|
|
||||||
_ = fwpmEngineClose0(session)
|
|
||||||
return 0, fmt.Errorf("register base objects: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
if cfg.strict {
|
|
||||||
if err := permitVirtualDNSIP(session, base, cfg.virtualDNSIP, cfg.blockedPorts, 15); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("permit virtual dns: %w", err))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if err := permitTunInterface(session, base, 15, cfg.tunLUID); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("permit tun interface: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := permitDaemonByAppID(session, base, cfg.daemonExe, 14); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("permit netbird daemon: %w", err))
|
|
||||||
}
|
|
||||||
if err := blockDNSPorts(session, base, cfg.blockedPorts, 10); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("block dns ports: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return session, nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// closeSession tears down a WFP session previously opened by installFilters.
|
|
||||||
// All filters owned by the session are removed.
|
|
||||||
func closeSession(session uintptr) (err error) {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
err = fmt.Errorf("panic in closeSession: %v", r)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if session == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err := fwpmEngineClose0(session); err != nil {
|
|
||||||
return wrapErr(err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func createSession() (uintptr, error) {
|
|
||||||
displayData, err := createWtFwpmDisplayData0("NetBird DNS firewall", "NetBird DNS firewall dynamic session")
|
|
||||||
if err != nil {
|
|
||||||
return 0, wrapErr(err)
|
|
||||||
}
|
|
||||||
session := wtFwpmSession0{
|
|
||||||
displayData: *displayData,
|
|
||||||
flags: cFWPM_SESSION_FLAG_DYNAMIC,
|
|
||||||
txnWaitTimeoutInMSec: windows.INFINITE,
|
|
||||||
}
|
|
||||||
var handle uintptr
|
|
||||||
if err := fwpmEngineOpen0(nil, cRPC_C_AUTHN_WINNT, nil, &session, unsafe.Pointer(&handle)); err != nil {
|
|
||||||
return 0, wrapErr(err)
|
|
||||||
}
|
|
||||||
return handle, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func registerBaseObjects(session uintptr) (*baseObjects, error) {
|
|
||||||
bo := &baseObjects{}
|
|
||||||
var err error
|
|
||||||
if bo.provider, err = windows.GenerateGUID(); err != nil {
|
|
||||||
return nil, wrapErr(err)
|
|
||||||
}
|
|
||||||
if bo.filters, err = windows.GenerateGUID(); err != nil {
|
|
||||||
return nil, wrapErr(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
displayData, err := createWtFwpmDisplayData0("NetBird DNS firewall", "NetBird DNS firewall provider")
|
|
||||||
if err != nil {
|
|
||||||
return nil, wrapErr(err)
|
|
||||||
}
|
|
||||||
provider := wtFwpmProvider0{
|
|
||||||
providerKey: bo.provider,
|
|
||||||
displayData: *displayData,
|
|
||||||
}
|
|
||||||
if err := fwpmProviderAdd0(session, &provider, 0); err != nil {
|
|
||||||
return nil, wrapErr(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
subDisplay, err := createWtFwpmDisplayData0("NetBird DNS firewall filters", "Permit and block filters")
|
|
||||||
if err != nil {
|
|
||||||
return nil, wrapErr(err)
|
|
||||||
}
|
|
||||||
sublayer := wtFwpmSublayer0{
|
|
||||||
subLayerKey: bo.filters,
|
|
||||||
displayData: *subDisplay,
|
|
||||||
providerKey: &bo.provider,
|
|
||||||
weight: ^uint16(0),
|
|
||||||
}
|
|
||||||
if err := fwpmSubLayerAdd0(session, &sublayer, 0); err != nil {
|
|
||||||
return nil, wrapErr(err)
|
|
||||||
}
|
|
||||||
return bo, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// daemonAppID returns the WFP App-ID byte blob for the given executable path.
|
|
||||||
func daemonAppID(path string) (*wtFwpByteBlob, error) {
|
|
||||||
pathPtr, err := windows.UTF16PtrFromString(path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, wrapErr(err)
|
|
||||||
}
|
|
||||||
var appID *wtFwpByteBlob
|
|
||||||
if err := fwpmGetAppIdFromFileName0(pathPtr, unsafe.Pointer(&appID)); err != nil {
|
|
||||||
return nil, wrapErr(err)
|
|
||||||
}
|
|
||||||
return appID, nil
|
|
||||||
}
|
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
*
|
|
||||||
* Adapted from wireguard-windows tunnel/firewall/syscall_windows.go.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package dnsfw
|
|
||||||
|
|
||||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmengineopen0
|
|
||||||
//sys fwpmEngineOpen0(serverName *uint16, authnService wtRpcCAuthN, authIdentity *uintptr, session *wtFwpmSession0, engineHandle unsafe.Pointer) (err error) [failretval!=0] = fwpuclnt.FwpmEngineOpen0
|
|
||||||
|
|
||||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmengineclose0
|
|
||||||
//sys fwpmEngineClose0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmEngineClose0
|
|
||||||
|
|
||||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmsublayeradd0
|
|
||||||
//sys fwpmSubLayerAdd0(engineHandle uintptr, subLayer *wtFwpmSublayer0, sd uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmSubLayerAdd0
|
|
||||||
|
|
||||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmgetappidfromfilename0
|
|
||||||
//sys fwpmGetAppIdFromFileName0(fileName *uint16, appID unsafe.Pointer) (err error) [failretval!=0] = fwpuclnt.FwpmGetAppIdFromFileName0
|
|
||||||
|
|
||||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmfreememory0
|
|
||||||
//sys fwpmFreeMemory0(p unsafe.Pointer) = fwpuclnt.FwpmFreeMemory0
|
|
||||||
|
|
||||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmfilteradd0
|
|
||||||
//sys fwpmFilterAdd0(engineHandle uintptr, filter *wtFwpmFilter0, sd uintptr, id *uint64) (err error) [failretval!=0] = fwpuclnt.FwpmFilterAdd0
|
|
||||||
|
|
||||||
// https://docs.microsoft.com/en-us/windows/desktop/api/Fwpmu/nf-fwpmu-fwpmtransactionbegin0
|
|
||||||
//sys fwpmTransactionBegin0(engineHandle uintptr, flags uint32) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionBegin0
|
|
||||||
|
|
||||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmtransactioncommit0
|
|
||||||
//sys fwpmTransactionCommit0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionCommit0
|
|
||||||
|
|
||||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmtransactionabort0
|
|
||||||
//sys fwpmTransactionAbort0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionAbort0
|
|
||||||
|
|
||||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmprovideradd0
|
|
||||||
//sys fwpmProviderAdd0(engineHandle uintptr, provider *wtFwpmProvider0, sd uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmProviderAdd0
|
|
||||||
@@ -1,414 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
*
|
|
||||||
* Adapted from wireguard-windows tunnel/firewall/types_windows.go.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package dnsfw
|
|
||||||
|
|
||||||
import "golang.org/x/sys/windows"
|
|
||||||
|
|
||||||
const (
|
|
||||||
anysizeArray = 1 // ANYSIZE_ARRAY defined in winnt.h
|
|
||||||
|
|
||||||
wtFwpBitmapArray64_Size = 8
|
|
||||||
|
|
||||||
wtFwpByteArray16_Size = 16
|
|
||||||
|
|
||||||
wtFwpByteArray6_Size = 6
|
|
||||||
|
|
||||||
wtFwpmAction0_Size = 20
|
|
||||||
wtFwpmAction0_filterType_Offset = 4
|
|
||||||
|
|
||||||
wtFwpV4AddrAndMask_Size = 8
|
|
||||||
wtFwpV4AddrAndMask_mask_Offset = 4
|
|
||||||
|
|
||||||
wtFwpV6AddrAndMask_Size = 17
|
|
||||||
wtFwpV6AddrAndMask_prefixLength_Offset = 16
|
|
||||||
)
|
|
||||||
|
|
||||||
type wtFwpActionFlag uint32
|
|
||||||
|
|
||||||
const (
|
|
||||||
cFWP_ACTION_FLAG_TERMINATING wtFwpActionFlag = 0x00001000
|
|
||||||
cFWP_ACTION_FLAG_NON_TERMINATING wtFwpActionFlag = 0x00002000
|
|
||||||
cFWP_ACTION_FLAG_CALLOUT wtFwpActionFlag = 0x00004000
|
|
||||||
)
|
|
||||||
|
|
||||||
// FWP_ACTION_TYPE defined in fwptypes.h
|
|
||||||
type wtFwpActionType uint32
|
|
||||||
|
|
||||||
const (
|
|
||||||
cFWP_ACTION_BLOCK wtFwpActionType = wtFwpActionType(0x00000001 | cFWP_ACTION_FLAG_TERMINATING)
|
|
||||||
cFWP_ACTION_PERMIT wtFwpActionType = wtFwpActionType(0x00000002 | cFWP_ACTION_FLAG_TERMINATING)
|
|
||||||
cFWP_ACTION_CALLOUT_TERMINATING wtFwpActionType = wtFwpActionType(0x00000003 | cFWP_ACTION_FLAG_CALLOUT | cFWP_ACTION_FLAG_TERMINATING)
|
|
||||||
cFWP_ACTION_CALLOUT_INSPECTION wtFwpActionType = wtFwpActionType(0x00000004 | cFWP_ACTION_FLAG_CALLOUT | cFWP_ACTION_FLAG_NON_TERMINATING)
|
|
||||||
cFWP_ACTION_CALLOUT_UNKNOWN wtFwpActionType = wtFwpActionType(0x00000005 | cFWP_ACTION_FLAG_CALLOUT)
|
|
||||||
cFWP_ACTION_CONTINUE wtFwpActionType = wtFwpActionType(0x00000006 | cFWP_ACTION_FLAG_NON_TERMINATING)
|
|
||||||
cFWP_ACTION_NONE wtFwpActionType = 0x00000007
|
|
||||||
cFWP_ACTION_NONE_NO_MATCH wtFwpActionType = 0x00000008
|
|
||||||
cFWP_ACTION_BITMAP_INDEX_SET wtFwpActionType = 0x00000009
|
|
||||||
)
|
|
||||||
|
|
||||||
// FWP_BYTE_BLOB defined in fwptypes.h
|
|
||||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_blob_)
|
|
||||||
type wtFwpByteBlob struct {
|
|
||||||
size uint32
|
|
||||||
data *uint8
|
|
||||||
}
|
|
||||||
|
|
||||||
// FWP_MATCH_TYPE defined in fwptypes.h
|
|
||||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ne-fwptypes-fwp_match_type_)
|
|
||||||
type wtFwpMatchType uint32
|
|
||||||
|
|
||||||
const (
|
|
||||||
cFWP_MATCH_EQUAL wtFwpMatchType = 0
|
|
||||||
cFWP_MATCH_GREATER wtFwpMatchType = cFWP_MATCH_EQUAL + 1
|
|
||||||
cFWP_MATCH_LESS wtFwpMatchType = cFWP_MATCH_GREATER + 1
|
|
||||||
cFWP_MATCH_GREATER_OR_EQUAL wtFwpMatchType = cFWP_MATCH_LESS + 1
|
|
||||||
cFWP_MATCH_LESS_OR_EQUAL wtFwpMatchType = cFWP_MATCH_GREATER_OR_EQUAL + 1
|
|
||||||
cFWP_MATCH_RANGE wtFwpMatchType = cFWP_MATCH_LESS_OR_EQUAL + 1
|
|
||||||
cFWP_MATCH_FLAGS_ALL_SET wtFwpMatchType = cFWP_MATCH_RANGE + 1
|
|
||||||
cFWP_MATCH_FLAGS_ANY_SET wtFwpMatchType = cFWP_MATCH_FLAGS_ALL_SET + 1
|
|
||||||
cFWP_MATCH_FLAGS_NONE_SET wtFwpMatchType = cFWP_MATCH_FLAGS_ANY_SET + 1
|
|
||||||
cFWP_MATCH_EQUAL_CASE_INSENSITIVE wtFwpMatchType = cFWP_MATCH_FLAGS_NONE_SET + 1
|
|
||||||
cFWP_MATCH_NOT_EQUAL wtFwpMatchType = cFWP_MATCH_EQUAL_CASE_INSENSITIVE + 1
|
|
||||||
cFWP_MATCH_PREFIX wtFwpMatchType = cFWP_MATCH_NOT_EQUAL + 1
|
|
||||||
cFWP_MATCH_NOT_PREFIX wtFwpMatchType = cFWP_MATCH_PREFIX + 1
|
|
||||||
cFWP_MATCH_TYPE_MAX wtFwpMatchType = cFWP_MATCH_NOT_PREFIX + 1
|
|
||||||
)
|
|
||||||
|
|
||||||
// FWPM_ACTION0 defined in fwpmtypes.h
|
|
||||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_action0_)
|
|
||||||
type wtFwpmAction0 struct {
|
|
||||||
_type wtFwpActionType
|
|
||||||
filterType windows.GUID // Windows type: GUID
|
|
||||||
}
|
|
||||||
|
|
||||||
// Defined in fwpmu.h. 4cd62a49-59c3-4969-b7f3-bda5d32890a4
|
|
||||||
var cFWPM_CONDITION_IP_LOCAL_INTERFACE = windows.GUID{
|
|
||||||
Data1: 0x4cd62a49,
|
|
||||||
Data2: 0x59c3,
|
|
||||||
Data3: 0x4969,
|
|
||||||
Data4: [8]byte{0xb7, 0xf3, 0xbd, 0xa5, 0xd3, 0x28, 0x90, 0xa4},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Defined in fwpmu.h. b235ae9a-1d64-49b8-a44c-5ff3d9095045
|
|
||||||
var cFWPM_CONDITION_IP_REMOTE_ADDRESS = windows.GUID{
|
|
||||||
Data1: 0xb235ae9a,
|
|
||||||
Data2: 0x1d64,
|
|
||||||
Data3: 0x49b8,
|
|
||||||
Data4: [8]byte{0xa4, 0x4c, 0x5f, 0xf3, 0xd9, 0x09, 0x50, 0x45},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Defined in fwpmu.h. 3971ef2b-623e-4f9a-8cb1-6e79b806b9a7
|
|
||||||
var cFWPM_CONDITION_IP_PROTOCOL = windows.GUID{
|
|
||||||
Data1: 0x3971ef2b,
|
|
||||||
Data2: 0x623e,
|
|
||||||
Data3: 0x4f9a,
|
|
||||||
Data4: [8]byte{0x8c, 0xb1, 0x6e, 0x79, 0xb8, 0x06, 0xb9, 0xa7},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Defined in fwpmu.h. 0c1ba1af-5765-453f-af22-a8f791ac775b
|
|
||||||
var cFWPM_CONDITION_IP_LOCAL_PORT = windows.GUID{
|
|
||||||
Data1: 0x0c1ba1af,
|
|
||||||
Data2: 0x5765,
|
|
||||||
Data3: 0x453f,
|
|
||||||
Data4: [8]byte{0xaf, 0x22, 0xa8, 0xf7, 0x91, 0xac, 0x77, 0x5b},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Defined in fwpmu.h. c35a604d-d22b-4e1a-91b4-68f674ee674b
|
|
||||||
var cFWPM_CONDITION_IP_REMOTE_PORT = windows.GUID{
|
|
||||||
Data1: 0xc35a604d,
|
|
||||||
Data2: 0xd22b,
|
|
||||||
Data3: 0x4e1a,
|
|
||||||
Data4: [8]byte{0x91, 0xb4, 0x68, 0xf6, 0x74, 0xee, 0x67, 0x4b},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Defined in fwpmu.h. d78e1e87-8644-4ea5-9437-d809ecefc971
|
|
||||||
var cFWPM_CONDITION_ALE_APP_ID = windows.GUID{
|
|
||||||
Data1: 0xd78e1e87,
|
|
||||||
Data2: 0x8644,
|
|
||||||
Data3: 0x4ea5,
|
|
||||||
Data4: [8]byte{0x94, 0x37, 0xd8, 0x09, 0xec, 0xef, 0xc9, 0x71},
|
|
||||||
}
|
|
||||||
|
|
||||||
// af043a0a-b34d-4f86-979c-c90371af6e66
|
|
||||||
var cFWPM_CONDITION_ALE_USER_ID = windows.GUID{
|
|
||||||
Data1: 0xaf043a0a,
|
|
||||||
Data2: 0xb34d,
|
|
||||||
Data3: 0x4f86,
|
|
||||||
Data4: [8]byte{0x97, 0x9c, 0xc9, 0x03, 0x71, 0xaf, 0x6e, 0x66},
|
|
||||||
}
|
|
||||||
|
|
||||||
// d9ee00de-c1ef-4617-bfe3-ffd8f5a08957
|
|
||||||
var cFWPM_CONDITION_IP_LOCAL_ADDRESS = windows.GUID{
|
|
||||||
Data1: 0xd9ee00de,
|
|
||||||
Data2: 0xc1ef,
|
|
||||||
Data3: 0x4617,
|
|
||||||
Data4: [8]byte{0xbf, 0xe3, 0xff, 0xd8, 0xf5, 0xa0, 0x89, 0x57},
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
cFWPM_CONDITION_ICMP_TYPE = cFWPM_CONDITION_IP_LOCAL_PORT
|
|
||||||
cFWPM_CONDITION_ICMP_CODE = cFWPM_CONDITION_IP_REMOTE_PORT
|
|
||||||
)
|
|
||||||
|
|
||||||
// 7bc43cbf-37ba-45f1-b74a-82ff518eeb10
|
|
||||||
var cFWPM_CONDITION_L2_FLAGS = windows.GUID{
|
|
||||||
Data1: 0x7bc43cbf,
|
|
||||||
Data2: 0x37ba,
|
|
||||||
Data3: 0x45f1,
|
|
||||||
Data4: [8]byte{0xb7, 0x4a, 0x82, 0xff, 0x51, 0x8e, 0xeb, 0x10},
|
|
||||||
}
|
|
||||||
|
|
||||||
type wtFwpmL2Flags uint32
|
|
||||||
|
|
||||||
const cFWP_CONDITION_L2_IS_VM2VM wtFwpmL2Flags = 0x00000010
|
|
||||||
|
|
||||||
var cFWPM_CONDITION_FLAGS = windows.GUID{
|
|
||||||
Data1: 0x632ce23b,
|
|
||||||
Data2: 0x5167,
|
|
||||||
Data3: 0x435c,
|
|
||||||
Data4: [8]byte{0x86, 0xd7, 0xe9, 0x03, 0x68, 0x4a, 0xa8, 0x0c},
|
|
||||||
}
|
|
||||||
|
|
||||||
type wtFwpmFlags uint32
|
|
||||||
|
|
||||||
const cFWP_CONDITION_FLAG_IS_LOOPBACK wtFwpmFlags = 0x00000001
|
|
||||||
|
|
||||||
// Defined in fwpmtypes.h
|
|
||||||
type wtFwpmFilterFlags uint32
|
|
||||||
|
|
||||||
const (
|
|
||||||
cFWPM_FILTER_FLAG_NONE wtFwpmFilterFlags = 0x00000000
|
|
||||||
cFWPM_FILTER_FLAG_PERSISTENT wtFwpmFilterFlags = 0x00000001
|
|
||||||
cFWPM_FILTER_FLAG_BOOTTIME wtFwpmFilterFlags = 0x00000002
|
|
||||||
cFWPM_FILTER_FLAG_HAS_PROVIDER_CONTEXT wtFwpmFilterFlags = 0x00000004
|
|
||||||
cFWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT wtFwpmFilterFlags = 0x00000008
|
|
||||||
cFWPM_FILTER_FLAG_PERMIT_IF_CALLOUT_UNREGISTERED wtFwpmFilterFlags = 0x00000010
|
|
||||||
cFWPM_FILTER_FLAG_DISABLED wtFwpmFilterFlags = 0x00000020
|
|
||||||
cFWPM_FILTER_FLAG_INDEXED wtFwpmFilterFlags = 0x00000040
|
|
||||||
cFWPM_FILTER_FLAG_HAS_SECURITY_REALM_PROVIDER_CONTEXT wtFwpmFilterFlags = 0x00000080
|
|
||||||
cFWPM_FILTER_FLAG_SYSTEMOS_ONLY wtFwpmFilterFlags = 0x00000100
|
|
||||||
cFWPM_FILTER_FLAG_GAMEOS_ONLY wtFwpmFilterFlags = 0x00000200
|
|
||||||
cFWPM_FILTER_FLAG_SILENT_MODE wtFwpmFilterFlags = 0x00000400
|
|
||||||
cFWPM_FILTER_FLAG_IPSEC_NO_ACQUIRE_INITIATE wtFwpmFilterFlags = 0x00000800
|
|
||||||
)
|
|
||||||
|
|
||||||
// FWPM_LAYER_ALE_AUTH_CONNECT_V4 (c38d57d1-05a7-4c33-904f-7fbceee60e82) defined in fwpmu.h
|
|
||||||
var cFWPM_LAYER_ALE_AUTH_CONNECT_V4 = windows.GUID{
|
|
||||||
Data1: 0xc38d57d1,
|
|
||||||
Data2: 0x05a7,
|
|
||||||
Data3: 0x4c33,
|
|
||||||
Data4: [8]byte{0x90, 0x4f, 0x7f, 0xbc, 0xee, 0xe6, 0x0e, 0x82},
|
|
||||||
}
|
|
||||||
|
|
||||||
// e1cd9fe7-f4b5-4273-96c0-592e487b8650
|
|
||||||
var cFWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 = windows.GUID{
|
|
||||||
Data1: 0xe1cd9fe7,
|
|
||||||
Data2: 0xf4b5,
|
|
||||||
Data3: 0x4273,
|
|
||||||
Data4: [8]byte{0x96, 0xc0, 0x59, 0x2e, 0x48, 0x7b, 0x86, 0x50},
|
|
||||||
}
|
|
||||||
|
|
||||||
// FWPM_LAYER_ALE_AUTH_CONNECT_V6 (4a72393b-319f-44bc-84c3-ba54dcb3b6b4) defined in fwpmu.h
|
|
||||||
var cFWPM_LAYER_ALE_AUTH_CONNECT_V6 = windows.GUID{
|
|
||||||
Data1: 0x4a72393b,
|
|
||||||
Data2: 0x319f,
|
|
||||||
Data3: 0x44bc,
|
|
||||||
Data4: [8]byte{0x84, 0xc3, 0xba, 0x54, 0xdc, 0xb3, 0xb6, 0xb4},
|
|
||||||
}
|
|
||||||
|
|
||||||
// a3b42c97-9f04-4672-b87e-cee9c483257f
|
|
||||||
var cFWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6 = windows.GUID{
|
|
||||||
Data1: 0xa3b42c97,
|
|
||||||
Data2: 0x9f04,
|
|
||||||
Data3: 0x4672,
|
|
||||||
Data4: [8]byte{0xb8, 0x7e, 0xce, 0xe9, 0xc4, 0x83, 0x25, 0x7f},
|
|
||||||
}
|
|
||||||
|
|
||||||
// 94c44912-9d6f-4ebf-b995-05ab8a088d1b
|
|
||||||
var cFWPM_LAYER_OUTBOUND_MAC_FRAME_NATIVE = windows.GUID{
|
|
||||||
Data1: 0x94c44912,
|
|
||||||
Data2: 0x9d6f,
|
|
||||||
Data3: 0x4ebf,
|
|
||||||
Data4: [8]byte{0xb9, 0x95, 0x05, 0xab, 0x8a, 0x08, 0x8d, 0x1b},
|
|
||||||
}
|
|
||||||
|
|
||||||
// d4220bd3-62ce-4f08-ae88-b56e8526df50
|
|
||||||
var cFWPM_LAYER_INBOUND_MAC_FRAME_NATIVE = windows.GUID{
|
|
||||||
Data1: 0xd4220bd3,
|
|
||||||
Data2: 0x62ce,
|
|
||||||
Data3: 0x4f08,
|
|
||||||
Data4: [8]byte{0xae, 0x88, 0xb5, 0x6e, 0x85, 0x26, 0xdf, 0x50},
|
|
||||||
}
|
|
||||||
|
|
||||||
// FWP_BITMAP_ARRAY64 defined in fwtypes.h
|
|
||||||
type wtFwpBitmapArray64 struct {
|
|
||||||
bitmapArray64 [8]uint8 // Windows type: [8]UINT8
|
|
||||||
}
|
|
||||||
|
|
||||||
// FWP_BYTE_ARRAY6 defined in fwtypes.h
|
|
||||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_array6_)
|
|
||||||
type wtFwpByteArray6 struct {
|
|
||||||
byteArray6 [6]uint8 // Windows type: [6]UINT8
|
|
||||||
}
|
|
||||||
|
|
||||||
// FWP_BYTE_ARRAY16 defined in fwptypes.h
|
|
||||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_array16_)
|
|
||||||
type wtFwpByteArray16 struct {
|
|
||||||
byteArray16 [16]uint8 // Windows type [16]UINT8
|
|
||||||
}
|
|
||||||
|
|
||||||
// FWP_CONDITION_VALUE0 defined in fwptypes.h
|
|
||||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_condition_value0).
|
|
||||||
type wtFwpConditionValue0 wtFwpValue0
|
|
||||||
|
|
||||||
// FWP_DATA_TYPE defined in fwptypes.h
|
|
||||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ne-fwptypes-fwp_data_type_)
|
|
||||||
type wtFwpDataType uint
|
|
||||||
|
|
||||||
const (
|
|
||||||
cFWP_EMPTY wtFwpDataType = 0
|
|
||||||
cFWP_UINT8 wtFwpDataType = cFWP_EMPTY + 1
|
|
||||||
cFWP_UINT16 wtFwpDataType = cFWP_UINT8 + 1
|
|
||||||
cFWP_UINT32 wtFwpDataType = cFWP_UINT16 + 1
|
|
||||||
cFWP_UINT64 wtFwpDataType = cFWP_UINT32 + 1
|
|
||||||
cFWP_INT8 wtFwpDataType = cFWP_UINT64 + 1
|
|
||||||
cFWP_INT16 wtFwpDataType = cFWP_INT8 + 1
|
|
||||||
cFWP_INT32 wtFwpDataType = cFWP_INT16 + 1
|
|
||||||
cFWP_INT64 wtFwpDataType = cFWP_INT32 + 1
|
|
||||||
cFWP_FLOAT wtFwpDataType = cFWP_INT64 + 1
|
|
||||||
cFWP_DOUBLE wtFwpDataType = cFWP_FLOAT + 1
|
|
||||||
cFWP_BYTE_ARRAY16_TYPE wtFwpDataType = cFWP_DOUBLE + 1
|
|
||||||
cFWP_BYTE_BLOB_TYPE wtFwpDataType = cFWP_BYTE_ARRAY16_TYPE + 1
|
|
||||||
cFWP_SID wtFwpDataType = cFWP_BYTE_BLOB_TYPE + 1
|
|
||||||
cFWP_SECURITY_DESCRIPTOR_TYPE wtFwpDataType = cFWP_SID + 1
|
|
||||||
cFWP_TOKEN_INFORMATION_TYPE wtFwpDataType = cFWP_SECURITY_DESCRIPTOR_TYPE + 1
|
|
||||||
cFWP_TOKEN_ACCESS_INFORMATION_TYPE wtFwpDataType = cFWP_TOKEN_INFORMATION_TYPE + 1
|
|
||||||
cFWP_UNICODE_STRING_TYPE wtFwpDataType = cFWP_TOKEN_ACCESS_INFORMATION_TYPE + 1
|
|
||||||
cFWP_BYTE_ARRAY6_TYPE wtFwpDataType = cFWP_UNICODE_STRING_TYPE + 1
|
|
||||||
cFWP_BITMAP_INDEX_TYPE wtFwpDataType = cFWP_BYTE_ARRAY6_TYPE + 1
|
|
||||||
cFWP_BITMAP_ARRAY64_TYPE wtFwpDataType = cFWP_BITMAP_INDEX_TYPE + 1
|
|
||||||
cFWP_SINGLE_DATA_TYPE_MAX wtFwpDataType = 0xff
|
|
||||||
cFWP_V4_ADDR_MASK wtFwpDataType = cFWP_SINGLE_DATA_TYPE_MAX + 1
|
|
||||||
cFWP_V6_ADDR_MASK wtFwpDataType = cFWP_V4_ADDR_MASK + 1
|
|
||||||
cFWP_RANGE_TYPE wtFwpDataType = cFWP_V6_ADDR_MASK + 1
|
|
||||||
cFWP_DATA_TYPE_MAX wtFwpDataType = cFWP_RANGE_TYPE + 1
|
|
||||||
)
|
|
||||||
|
|
||||||
// FWP_V4_ADDR_AND_MASK defined in fwptypes.h
|
|
||||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_v4_addr_and_mask).
|
|
||||||
type wtFwpV4AddrAndMask struct {
|
|
||||||
addr uint32
|
|
||||||
mask uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// FWP_V6_ADDR_AND_MASK defined in fwptypes.h
|
|
||||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_v6_addr_and_mask).
|
|
||||||
type wtFwpV6AddrAndMask struct {
|
|
||||||
addr [16]uint8
|
|
||||||
prefixLength uint8
|
|
||||||
}
|
|
||||||
|
|
||||||
// FWP_VALUE0 defined in fwptypes.h
|
|
||||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_value0_)
|
|
||||||
type wtFwpValue0 struct {
|
|
||||||
_type wtFwpDataType
|
|
||||||
value uintptr
|
|
||||||
}
|
|
||||||
|
|
||||||
// FWPM_DISPLAY_DATA0 defined in fwptypes.h
|
|
||||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwpm_display_data0).
|
|
||||||
type wtFwpmDisplayData0 struct {
|
|
||||||
name *uint16 // Windows type: *wchar_t
|
|
||||||
description *uint16 // Windows type: *wchar_t
|
|
||||||
}
|
|
||||||
|
|
||||||
// FWPM_FILTER_CONDITION0 defined in fwpmtypes.h
|
|
||||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter_condition0).
|
|
||||||
type wtFwpmFilterCondition0 struct {
|
|
||||||
fieldKey windows.GUID // Windows type: GUID
|
|
||||||
matchType wtFwpMatchType
|
|
||||||
conditionValue wtFwpConditionValue0
|
|
||||||
}
|
|
||||||
|
|
||||||
// FWPM_PROVIDER0 defined in fwpmtypes.h
|
|
||||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_provider0_)
|
|
||||||
type wtFwpProvider0 struct {
|
|
||||||
providerKey windows.GUID // Windows type: GUID
|
|
||||||
displayData wtFwpmDisplayData0
|
|
||||||
flags uint32
|
|
||||||
providerData wtFwpByteBlob
|
|
||||||
serviceName *uint16 // Windows type: *wchar_t
|
|
||||||
}
|
|
||||||
|
|
||||||
type wtFwpmSessionFlagsValue uint32
|
|
||||||
|
|
||||||
const (
|
|
||||||
cFWPM_SESSION_FLAG_DYNAMIC wtFwpmSessionFlagsValue = 0x00000001 // FWPM_SESSION_FLAG_DYNAMIC defined in fwpmtypes.h
|
|
||||||
)
|
|
||||||
|
|
||||||
// FWPM_SESSION0 defined in fwpmtypes.h
|
|
||||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_session0).
|
|
||||||
type wtFwpmSession0 struct {
|
|
||||||
sessionKey windows.GUID // Windows type: GUID
|
|
||||||
displayData wtFwpmDisplayData0
|
|
||||||
flags wtFwpmSessionFlagsValue // Windows type UINT32
|
|
||||||
txnWaitTimeoutInMSec uint32
|
|
||||||
processId uint32 // Windows type: DWORD
|
|
||||||
sid *windows.SID
|
|
||||||
username *uint16 // Windows type: *wchar_t
|
|
||||||
kernelMode uint8 // Windows type: BOOL
|
|
||||||
}
|
|
||||||
|
|
||||||
type wtFwpmSublayerFlags uint32
|
|
||||||
|
|
||||||
const (
|
|
||||||
cFWPM_SUBLAYER_FLAG_PERSISTENT wtFwpmSublayerFlags = 0x00000001 // FWPM_SUBLAYER_FLAG_PERSISTENT defined in fwpmtypes.h
|
|
||||||
)
|
|
||||||
|
|
||||||
// FWPM_SUBLAYER0 defined in fwpmtypes.h
|
|
||||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_sublayer0_)
|
|
||||||
type wtFwpmSublayer0 struct {
|
|
||||||
subLayerKey windows.GUID // Windows type: GUID
|
|
||||||
displayData wtFwpmDisplayData0
|
|
||||||
flags wtFwpmSublayerFlags
|
|
||||||
providerKey *windows.GUID // Windows type: *GUID
|
|
||||||
providerData wtFwpByteBlob
|
|
||||||
weight uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
// Defined in rpcdce.h
|
|
||||||
type wtRpcCAuthN uint32
|
|
||||||
|
|
||||||
const (
|
|
||||||
cRPC_C_AUTHN_NONE wtRpcCAuthN = 0
|
|
||||||
cRPC_C_AUTHN_WINNT wtRpcCAuthN = 10
|
|
||||||
cRPC_C_AUTHN_DEFAULT wtRpcCAuthN = 0xFFFFFFFF
|
|
||||||
)
|
|
||||||
|
|
||||||
// FWPM_PROVIDER0 defined in fwpmtypes.h
|
|
||||||
// (https://docs.microsoft.com/sv-se/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_provider0).
|
|
||||||
type wtFwpmProvider0 struct {
|
|
||||||
providerKey windows.GUID
|
|
||||||
displayData wtFwpmDisplayData0
|
|
||||||
flags uint32
|
|
||||||
providerData wtFwpByteBlob
|
|
||||||
serviceName *uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
type wtIPProto uint32
|
|
||||||
|
|
||||||
const (
|
|
||||||
cIPPROTO_ICMP wtIPProto = 1
|
|
||||||
cIPPROTO_ICMPV6 wtIPProto = 58
|
|
||||||
cIPPROTO_TCP wtIPProto = 6
|
|
||||||
cIPPROTO_UDP wtIPProto = 17
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
cFWP_ACTRL_MATCH_FILTER = 1
|
|
||||||
)
|
|
||||||
@@ -1,92 +0,0 @@
|
|||||||
//go:build windows && (386 || arm)
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
*
|
|
||||||
* Adapted from wireguard-windows tunnel/firewall/types_windows_32.go.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package dnsfw
|
|
||||||
|
|
||||||
import "golang.org/x/sys/windows"
|
|
||||||
|
|
||||||
const (
|
|
||||||
wtFwpByteBlob_Size = 8
|
|
||||||
wtFwpByteBlob_data_Offset = 4
|
|
||||||
|
|
||||||
wtFwpConditionValue0_Size = 8
|
|
||||||
wtFwpConditionValue0_uint8_Offset = 4
|
|
||||||
|
|
||||||
wtFwpmDisplayData0_Size = 8
|
|
||||||
wtFwpmDisplayData0_description_Offset = 4
|
|
||||||
|
|
||||||
wtFwpmFilter0_Size = 152
|
|
||||||
wtFwpmFilter0_displayData_Offset = 16
|
|
||||||
wtFwpmFilter0_flags_Offset = 24
|
|
||||||
wtFwpmFilter0_providerKey_Offset = 28
|
|
||||||
wtFwpmFilter0_providerData_Offset = 32
|
|
||||||
wtFwpmFilter0_layerKey_Offset = 40
|
|
||||||
wtFwpmFilter0_subLayerKey_Offset = 56
|
|
||||||
wtFwpmFilter0_weight_Offset = 72
|
|
||||||
wtFwpmFilter0_numFilterConditions_Offset = 80
|
|
||||||
wtFwpmFilter0_filterCondition_Offset = 84
|
|
||||||
wtFwpmFilter0_action_Offset = 88
|
|
||||||
wtFwpmFilter0_providerContextKey_Offset = 112
|
|
||||||
wtFwpmFilter0_reserved_Offset = 128
|
|
||||||
wtFwpmFilter0_filterID_Offset = 136
|
|
||||||
wtFwpmFilter0_effectiveWeight_Offset = 144
|
|
||||||
|
|
||||||
wtFwpmFilterCondition0_Size = 28
|
|
||||||
wtFwpmFilterCondition0_matchType_Offset = 16
|
|
||||||
wtFwpmFilterCondition0_conditionValue_Offset = 20
|
|
||||||
|
|
||||||
wtFwpmSession0_Size = 48
|
|
||||||
wtFwpmSession0_displayData_Offset = 16
|
|
||||||
wtFwpmSession0_flags_Offset = 24
|
|
||||||
wtFwpmSession0_txnWaitTimeoutInMSec_Offset = 28
|
|
||||||
wtFwpmSession0_processId_Offset = 32
|
|
||||||
wtFwpmSession0_sid_Offset = 36
|
|
||||||
wtFwpmSession0_username_Offset = 40
|
|
||||||
wtFwpmSession0_kernelMode_Offset = 44
|
|
||||||
|
|
||||||
wtFwpmSublayer0_Size = 44
|
|
||||||
wtFwpmSublayer0_displayData_Offset = 16
|
|
||||||
wtFwpmSublayer0_flags_Offset = 24
|
|
||||||
wtFwpmSublayer0_providerKey_Offset = 28
|
|
||||||
wtFwpmSublayer0_providerData_Offset = 32
|
|
||||||
wtFwpmSublayer0_weight_Offset = 40
|
|
||||||
|
|
||||||
wtFwpProvider0_Size = 40
|
|
||||||
wtFwpProvider0_displayData_Offset = 16
|
|
||||||
wtFwpProvider0_flags_Offset = 24
|
|
||||||
wtFwpProvider0_providerData_Offset = 28
|
|
||||||
wtFwpProvider0_serviceName_Offset = 36
|
|
||||||
|
|
||||||
wtFwpTokenInformation_Size = 16
|
|
||||||
|
|
||||||
wtFwpValue0_Size = 8
|
|
||||||
wtFwpValue0_value_Offset = 4
|
|
||||||
)
|
|
||||||
|
|
||||||
// FWPM_FILTER0 defined in fwpmtypes.h
|
|
||||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter0).
|
|
||||||
type wtFwpmFilter0 struct {
|
|
||||||
filterKey windows.GUID // Windows type: GUID
|
|
||||||
displayData wtFwpmDisplayData0
|
|
||||||
flags wtFwpmFilterFlags
|
|
||||||
providerKey *windows.GUID // Windows type: *GUID
|
|
||||||
providerData wtFwpByteBlob
|
|
||||||
layerKey windows.GUID // Windows type: GUID
|
|
||||||
subLayerKey windows.GUID // Windows type: GUID
|
|
||||||
weight wtFwpValue0
|
|
||||||
numFilterConditions uint32
|
|
||||||
filterCondition *wtFwpmFilterCondition0
|
|
||||||
action wtFwpmAction0
|
|
||||||
offset1 [4]byte // Layout correction field
|
|
||||||
providerContextKey windows.GUID // Windows type: GUID
|
|
||||||
reserved *windows.GUID // Windows type: *GUID
|
|
||||||
offset2 [4]byte // Layout correction field
|
|
||||||
filterID uint64
|
|
||||||
effectiveWeight wtFwpValue0
|
|
||||||
}
|
|
||||||
@@ -1,89 +0,0 @@
|
|||||||
//go:build windows && (amd64 || arm64)
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
*
|
|
||||||
* Adapted from wireguard-windows tunnel/firewall/types_windows_64.go.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package dnsfw
|
|
||||||
|
|
||||||
import "golang.org/x/sys/windows"
|
|
||||||
|
|
||||||
const (
|
|
||||||
wtFwpByteBlob_Size = 16
|
|
||||||
wtFwpByteBlob_data_Offset = 8
|
|
||||||
|
|
||||||
wtFwpConditionValue0_Size = 16
|
|
||||||
wtFwpConditionValue0_uint8_Offset = 8
|
|
||||||
|
|
||||||
wtFwpmDisplayData0_Size = 16
|
|
||||||
wtFwpmDisplayData0_description_Offset = 8
|
|
||||||
|
|
||||||
wtFwpmFilter0_Size = 200
|
|
||||||
wtFwpmFilter0_displayData_Offset = 16
|
|
||||||
wtFwpmFilter0_flags_Offset = 32
|
|
||||||
wtFwpmFilter0_providerKey_Offset = 40
|
|
||||||
wtFwpmFilter0_providerData_Offset = 48
|
|
||||||
wtFwpmFilter0_layerKey_Offset = 64
|
|
||||||
wtFwpmFilter0_subLayerKey_Offset = 80
|
|
||||||
wtFwpmFilter0_weight_Offset = 96
|
|
||||||
wtFwpmFilter0_numFilterConditions_Offset = 112
|
|
||||||
wtFwpmFilter0_filterCondition_Offset = 120
|
|
||||||
wtFwpmFilter0_action_Offset = 128
|
|
||||||
wtFwpmFilter0_providerContextKey_Offset = 152
|
|
||||||
wtFwpmFilter0_reserved_Offset = 168
|
|
||||||
wtFwpmFilter0_filterID_Offset = 176
|
|
||||||
wtFwpmFilter0_effectiveWeight_Offset = 184
|
|
||||||
|
|
||||||
wtFwpmFilterCondition0_Size = 40
|
|
||||||
wtFwpmFilterCondition0_matchType_Offset = 16
|
|
||||||
wtFwpmFilterCondition0_conditionValue_Offset = 24
|
|
||||||
|
|
||||||
wtFwpmSession0_Size = 72
|
|
||||||
wtFwpmSession0_displayData_Offset = 16
|
|
||||||
wtFwpmSession0_flags_Offset = 32
|
|
||||||
wtFwpmSession0_txnWaitTimeoutInMSec_Offset = 36
|
|
||||||
wtFwpmSession0_processId_Offset = 40
|
|
||||||
wtFwpmSession0_sid_Offset = 48
|
|
||||||
wtFwpmSession0_username_Offset = 56
|
|
||||||
wtFwpmSession0_kernelMode_Offset = 64
|
|
||||||
|
|
||||||
wtFwpmSublayer0_Size = 72
|
|
||||||
wtFwpmSublayer0_displayData_Offset = 16
|
|
||||||
wtFwpmSublayer0_flags_Offset = 32
|
|
||||||
wtFwpmSublayer0_providerKey_Offset = 40
|
|
||||||
wtFwpmSublayer0_providerData_Offset = 48
|
|
||||||
wtFwpmSublayer0_weight_Offset = 64
|
|
||||||
|
|
||||||
wtFwpProvider0_Size = 64
|
|
||||||
wtFwpProvider0_displayData_Offset = 16
|
|
||||||
wtFwpProvider0_flags_Offset = 32
|
|
||||||
wtFwpProvider0_providerData_Offset = 40
|
|
||||||
wtFwpProvider0_serviceName_Offset = 56
|
|
||||||
|
|
||||||
wtFwpValue0_Size = 16
|
|
||||||
wtFwpValue0_value_Offset = 8
|
|
||||||
)
|
|
||||||
|
|
||||||
// FWPM_FILTER0 defined in fwpmtypes.h
|
|
||||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter0).
|
|
||||||
type wtFwpmFilter0 struct {
|
|
||||||
filterKey windows.GUID // Windows type: GUID
|
|
||||||
displayData wtFwpmDisplayData0
|
|
||||||
flags wtFwpmFilterFlags // Windows type: UINT32
|
|
||||||
providerKey *windows.GUID // Windows type: *GUID
|
|
||||||
providerData wtFwpByteBlob
|
|
||||||
layerKey windows.GUID // Windows type: GUID
|
|
||||||
subLayerKey windows.GUID // Windows type: GUID
|
|
||||||
weight wtFwpValue0
|
|
||||||
numFilterConditions uint32
|
|
||||||
filterCondition *wtFwpmFilterCondition0
|
|
||||||
action wtFwpmAction0
|
|
||||||
offset1 [4]byte // Layout correction field
|
|
||||||
providerContextKey windows.GUID // Windows type: GUID
|
|
||||||
reserved *windows.GUID // Windows type: *GUID
|
|
||||||
filterID uint64
|
|
||||||
effectiveWeight wtFwpValue0
|
|
||||||
}
|
|
||||||
@@ -1,130 +0,0 @@
|
|||||||
// Code generated by 'go generate'; DO NOT EDIT.
|
|
||||||
|
|
||||||
package dnsfw
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ unsafe.Pointer
|
|
||||||
|
|
||||||
// Do the interface allocations only once for common
|
|
||||||
// Errno values.
|
|
||||||
const (
|
|
||||||
errnoERROR_IO_PENDING = 997
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
|
||||||
errERROR_EINVAL error = syscall.EINVAL
|
|
||||||
)
|
|
||||||
|
|
||||||
// errnoErr returns common boxed Errno values, to prevent
|
|
||||||
// allocations at runtime.
|
|
||||||
func errnoErr(e syscall.Errno) error {
|
|
||||||
switch e {
|
|
||||||
case 0:
|
|
||||||
return errERROR_EINVAL
|
|
||||||
case errnoERROR_IO_PENDING:
|
|
||||||
return errERROR_IO_PENDING
|
|
||||||
}
|
|
||||||
// TODO: add more here, after collecting data on the common
|
|
||||||
// error values see on Windows. (perhaps when running
|
|
||||||
// all.bat?)
|
|
||||||
return e
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
modfwpuclnt = windows.NewLazySystemDLL("fwpuclnt.dll")
|
|
||||||
|
|
||||||
procFwpmEngineClose0 = modfwpuclnt.NewProc("FwpmEngineClose0")
|
|
||||||
procFwpmEngineOpen0 = modfwpuclnt.NewProc("FwpmEngineOpen0")
|
|
||||||
procFwpmFilterAdd0 = modfwpuclnt.NewProc("FwpmFilterAdd0")
|
|
||||||
procFwpmFreeMemory0 = modfwpuclnt.NewProc("FwpmFreeMemory0")
|
|
||||||
procFwpmGetAppIdFromFileName0 = modfwpuclnt.NewProc("FwpmGetAppIdFromFileName0")
|
|
||||||
procFwpmProviderAdd0 = modfwpuclnt.NewProc("FwpmProviderAdd0")
|
|
||||||
procFwpmSubLayerAdd0 = modfwpuclnt.NewProc("FwpmSubLayerAdd0")
|
|
||||||
procFwpmTransactionAbort0 = modfwpuclnt.NewProc("FwpmTransactionAbort0")
|
|
||||||
procFwpmTransactionBegin0 = modfwpuclnt.NewProc("FwpmTransactionBegin0")
|
|
||||||
procFwpmTransactionCommit0 = modfwpuclnt.NewProc("FwpmTransactionCommit0")
|
|
||||||
)
|
|
||||||
|
|
||||||
func fwpmEngineClose0(engineHandle uintptr) (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall(procFwpmEngineClose0.Addr(), 1, uintptr(engineHandle), 0, 0)
|
|
||||||
if r1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func fwpmEngineOpen0(serverName *uint16, authnService wtRpcCAuthN, authIdentity *uintptr, session *wtFwpmSession0, engineHandle unsafe.Pointer) (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall6(procFwpmEngineOpen0.Addr(), 5, uintptr(unsafe.Pointer(serverName)), uintptr(authnService), uintptr(unsafe.Pointer(authIdentity)), uintptr(unsafe.Pointer(session)), uintptr(engineHandle), 0)
|
|
||||||
if r1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func fwpmFilterAdd0(engineHandle uintptr, filter *wtFwpmFilter0, sd uintptr, id *uint64) (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall6(procFwpmFilterAdd0.Addr(), 4, uintptr(engineHandle), uintptr(unsafe.Pointer(filter)), uintptr(sd), uintptr(unsafe.Pointer(id)), 0, 0)
|
|
||||||
if r1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func fwpmFreeMemory0(p unsafe.Pointer) {
|
|
||||||
syscall.Syscall(procFwpmFreeMemory0.Addr(), 1, uintptr(p), 0, 0)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func fwpmGetAppIdFromFileName0(fileName *uint16, appID unsafe.Pointer) (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall(procFwpmGetAppIdFromFileName0.Addr(), 2, uintptr(unsafe.Pointer(fileName)), uintptr(appID), 0)
|
|
||||||
if r1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func fwpmProviderAdd0(engineHandle uintptr, provider *wtFwpmProvider0, sd uintptr) (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall(procFwpmProviderAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(provider)), uintptr(sd))
|
|
||||||
if r1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func fwpmSubLayerAdd0(engineHandle uintptr, subLayer *wtFwpmSublayer0, sd uintptr) (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall(procFwpmSubLayerAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(subLayer)), uintptr(sd))
|
|
||||||
if r1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func fwpmTransactionAbort0(engineHandle uintptr) (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall(procFwpmTransactionAbort0.Addr(), 1, uintptr(engineHandle), 0, 0)
|
|
||||||
if r1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func fwpmTransactionBegin0(engineHandle uintptr, flags uint32) (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall(procFwpmTransactionBegin0.Addr(), 2, uintptr(engineHandle), uintptr(flags), 0)
|
|
||||||
if r1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func fwpmTransactionCommit0(engineHandle uintptr) (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall(procFwpmTransactionCommit0.Addr(), 1, uintptr(engineHandle), 0, 0)
|
|
||||||
if r1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
@@ -16,7 +16,6 @@ import (
|
|||||||
"golang.org/x/sys/windows/registry"
|
"golang.org/x/sys/windows/registry"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/dnsfw"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/winregistry"
|
"github.com/netbirdio/netbird/client/internal/winregistry"
|
||||||
)
|
)
|
||||||
@@ -72,7 +71,6 @@ type registryConfigurator struct {
|
|||||||
routingAll bool
|
routingAll bool
|
||||||
gpo bool
|
gpo bool
|
||||||
nrptEntryCount int
|
nrptEntryCount int
|
||||||
dnsFirewall dnsfw.Manager
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||||
@@ -92,9 +90,8 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
configurator := ®istryConfigurator{
|
configurator := ®istryConfigurator{
|
||||||
guid: guid,
|
guid: guid,
|
||||||
gpo: useGPO,
|
gpo: useGPO,
|
||||||
dnsFirewall: dnsfw.New(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := configurator.configureInterface(); err != nil {
|
if err := configurator.configureInterface(); err != nil {
|
||||||
@@ -172,8 +169,16 @@ func (r *registryConfigurator) disableWINSForInterface() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||||
if err := r.applyRouteAll(config); err != nil {
|
if config.RouteAll {
|
||||||
return err
|
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
|
||||||
|
return fmt.Errorf("add dns setup: %w", err)
|
||||||
|
}
|
||||||
|
} else if r.routingAll {
|
||||||
|
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
|
||||||
|
return fmt.Errorf("delete interface registry key property: %w", err)
|
||||||
|
}
|
||||||
|
r.routingAll = false
|
||||||
|
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.updateState(stateManager)
|
r.updateState(stateManager)
|
||||||
@@ -215,35 +220,6 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) applyRouteAll(config HostDNSConfig) error {
|
|
||||||
if config.RouteAll {
|
|
||||||
if err := r.dnsFirewall.Enable(r.guid, config.ServerIP); err != nil {
|
|
||||||
return fmt.Errorf("dns firewall: %w", err)
|
|
||||||
}
|
|
||||||
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
|
|
||||||
merr := multierror.Append(nil, fmt.Errorf("add dns setup: %w", err))
|
|
||||||
if dErr := r.dnsFirewall.Disable(); dErr != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("rollback dns firewall: %w", dErr))
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.dnsFirewall.Disable(); err != nil {
|
|
||||||
log.Errorf("disable dns firewall: %v", err)
|
|
||||||
}
|
|
||||||
if !r.routingAll {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
|
|
||||||
return fmt.Errorf("delete interface registry key property: %w", err)
|
|
||||||
}
|
|
||||||
r.routingAll = false
|
|
||||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) {
|
func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) {
|
||||||
if err := stateManager.UpdateState(&ShutdownState{
|
if err := stateManager.UpdateState(&ShutdownState{
|
||||||
Guid: r.guid,
|
Guid: r.guid,
|
||||||
@@ -430,10 +406,6 @@ func (r *registryConfigurator) restoreHostDNS() error {
|
|||||||
return fmt.Errorf("remove interface registry key: %w", err)
|
return fmt.Errorf("remove interface registry key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.dnsFirewall.Disable(); err != nil {
|
|
||||||
log.Errorf("disable dns firewall: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
go r.flushDNSCache()
|
go r.flushDNSCache()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/sys/windows/registry"
|
"golang.org/x/sys/windows/registry"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/dnsfw"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
|
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
|
||||||
@@ -36,9 +34,8 @@ func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
cfg := ®istryConfigurator{
|
cfg := ®istryConfigurator{
|
||||||
guid: testGUID,
|
guid: testGUID,
|
||||||
gpo: false,
|
gpo: false,
|
||||||
dnsFirewall: dnsfw.New(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create 125 domains which will result in 3 NRPT rules (50+50+25)
|
// Create 125 domains which will result in 3 NRPT rules (50+50+25)
|
||||||
@@ -137,9 +134,8 @@ func TestNRPTDomainBatching(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
cfg := ®istryConfigurator{
|
cfg := ®istryConfigurator{
|
||||||
guid: testGUID,
|
guid: testGUID,
|
||||||
gpo: false,
|
gpo: false,
|
||||||
dnsFirewall: dnsfw.New(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
|
|||||||
@@ -89,21 +89,33 @@ func (p *Provider) ListConnectors(ctx context.Context) ([]*ConnectorConfig, erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateConnector updates an existing connector in Dex storage.
|
// UpdateConnector updates an existing connector in Dex storage.
|
||||||
// It merges incoming updates with existing values to prevent data loss on partial updates.
|
// It overlays user-mutable config fields (issuer, clientID, clientSecret,
|
||||||
|
// redirectURI) onto the stored connector config, and updates the connector name
|
||||||
|
// when cfg.Name is set. Empty fields on cfg leave stored values unchanged, so
|
||||||
|
// partial updates preserve create-time defaults such as scopes, claimMapping,
|
||||||
|
// and userIDKey.
|
||||||
func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error {
|
func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error {
|
||||||
if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) {
|
if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) {
|
||||||
oldCfg, err := p.parseStorageConnector(old)
|
if cfg.Type != "" && cfg.Type != inferIdentityProviderType(old.Type, cfg.ID, nil) {
|
||||||
if err != nil {
|
return storage.Connector{}, errors.New("connector type change not allowed")
|
||||||
return storage.Connector{}, fmt.Errorf("failed to parse existing connector: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mergeConnectorConfig(cfg, oldCfg)
|
configData, err := overlayConnectorConfig(old.Config, cfg)
|
||||||
|
|
||||||
storageConn, err := p.buildStorageConnector(cfg)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return storage.Connector{}, fmt.Errorf("failed to build connector: %w", err)
|
return storage.Connector{}, fmt.Errorf("failed to overlay connector config: %w", err)
|
||||||
}
|
}
|
||||||
return storageConn, nil
|
|
||||||
|
name := cfg.Name
|
||||||
|
if name == "" {
|
||||||
|
name = old.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
return storage.Connector{
|
||||||
|
ID: cfg.ID,
|
||||||
|
Type: old.Type,
|
||||||
|
Name: name,
|
||||||
|
Config: configData,
|
||||||
|
}, nil
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return fmt.Errorf("failed to update connector: %w", err)
|
return fmt.Errorf("failed to update connector: %w", err)
|
||||||
}
|
}
|
||||||
@@ -112,23 +124,27 @@ func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) er
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// mergeConnectorConfig preserves existing values for empty fields in the update.
|
// overlayConnectorConfig writes only the user-mutable fields onto the existing
|
||||||
func mergeConnectorConfig(cfg, oldCfg *ConnectorConfig) {
|
// stored config, preserving every other field (scopes, claimMapping, userIDKey,
|
||||||
if cfg.ClientSecret == "" {
|
// insecure flags, etc.). Empty fields on cfg leave the existing value alone.
|
||||||
cfg.ClientSecret = oldCfg.ClientSecret
|
func overlayConnectorConfig(oldConfig []byte, cfg *ConnectorConfig) ([]byte, error) {
|
||||||
|
var m map[string]any
|
||||||
|
if err := decodeConnectorConfig(oldConfig, &m); err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
if cfg.RedirectURI == "" {
|
if cfg.Issuer != "" {
|
||||||
cfg.RedirectURI = oldCfg.RedirectURI
|
m["issuer"] = cfg.Issuer
|
||||||
}
|
}
|
||||||
if cfg.Issuer == "" && cfg.Type == oldCfg.Type {
|
if cfg.ClientID != "" {
|
||||||
cfg.Issuer = oldCfg.Issuer
|
m["clientID"] = cfg.ClientID
|
||||||
}
|
}
|
||||||
if cfg.ClientID == "" {
|
if cfg.ClientSecret != "" {
|
||||||
cfg.ClientID = oldCfg.ClientID
|
m["clientSecret"] = cfg.ClientSecret
|
||||||
}
|
}
|
||||||
if cfg.Name == "" {
|
if cfg.RedirectURI != "" {
|
||||||
cfg.Name = oldCfg.Name
|
m["redirectURI"] = cfg.RedirectURI
|
||||||
}
|
}
|
||||||
|
return encodeConnectorConfig(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteConnector removes a connector from Dex storage.
|
// DeleteConnector removes a connector from Dex storage.
|
||||||
@@ -216,6 +232,10 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
|
|||||||
oidcConfig["getUserInfo"] = true
|
oidcConfig["getUserInfo"] = true
|
||||||
case "entra":
|
case "entra":
|
||||||
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
|
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
|
||||||
|
// Use the Entra Object ID (oid) instead of the default OIDC sub claim.
|
||||||
|
// Entra issues sub as a per-app pairwise identifier that does not match
|
||||||
|
// the stable Object ID.
|
||||||
|
oidcConfig["userIDKey"] = "oid"
|
||||||
case "okta":
|
case "okta":
|
||||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||||
case "pocketid":
|
case "pocketid":
|
||||||
|
|||||||
205
idp/dex/connector_test.go
Normal file
205
idp/dex/connector_test.go
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
package dex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
"github.com/dexidp/dex/storage/sql"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestProvider(t *testing.T) (*Provider, func()) {
|
||||||
|
t.Helper()
|
||||||
|
tmpDir, err := os.MkdirTemp("", "dex-connector-test-*")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||||
|
s, err := (&sql.SQLite3{File: filepath.Join(tmpDir, "dex.db")}).Open(logger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return &Provider{storage: s, logger: logger}, func() {
|
||||||
|
_ = s.Close()
|
||||||
|
_ = os.RemoveAll(tmpDir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildOIDCConnectorConfig_EntraSetsUserIDKey(t *testing.T) {
|
||||||
|
cfg := &ConnectorConfig{
|
||||||
|
ID: "entra-test",
|
||||||
|
Name: "Entra",
|
||||||
|
Type: "entra",
|
||||||
|
Issuer: "https://login.microsoftonline.com/tid/v2.0",
|
||||||
|
ClientID: "client-id",
|
||||||
|
ClientSecret: "client-secret",
|
||||||
|
}
|
||||||
|
data, err := buildOIDCConnectorConfig(cfg, "https://example.com/oauth2/callback")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var m map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(data, &m))
|
||||||
|
|
||||||
|
assert.Equal(t, "oid", m["userIDKey"], "entra connectors must default userIDKey to oid")
|
||||||
|
assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildOIDCConnectorConfig_NonEntraDoesNotSetUserIDKey(t *testing.T) {
|
||||||
|
// ensures the Entra userIDKey override does not leak into other OIDC providers,
|
||||||
|
// which already use a stable sub claim.
|
||||||
|
for _, typ := range []string{"oidc", "zitadel", "okta", "pocketid", "authentik", "keycloak", "adfs"} {
|
||||||
|
t.Run(typ, func(t *testing.T) {
|
||||||
|
data, err := buildOIDCConnectorConfig(&ConnectorConfig{Type: typ}, "https://example.com/oauth2/callback")
|
||||||
|
require.NoError(t, err)
|
||||||
|
var m map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(data, &m))
|
||||||
|
_, ok := m["userIDKey"]
|
||||||
|
assert.False(t, ok, "%s connectors must not have userIDKey set", typ)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateConnector_PreservesCreateTimeDefaults(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
p, cleanup := newTestProvider(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
created, err := p.CreateConnector(ctx, &ConnectorConfig{
|
||||||
|
ID: "entra-test",
|
||||||
|
Name: "Entra",
|
||||||
|
Type: "entra",
|
||||||
|
Issuer: "https://login.microsoftonline.com/tid/v2.0",
|
||||||
|
ClientID: "client-id",
|
||||||
|
ClientSecret: "old-secret",
|
||||||
|
RedirectURI: "https://example.com/oauth2/callback",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "entra-test", created.ID)
|
||||||
|
|
||||||
|
// Rotate only the client secret.
|
||||||
|
err = p.UpdateConnector(ctx, &ConnectorConfig{
|
||||||
|
ID: "entra-test",
|
||||||
|
Type: "entra",
|
||||||
|
ClientSecret: "new-secret",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
conn, err := p.storage.GetConnector(ctx, "entra-test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
var m map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(conn.Config, &m))
|
||||||
|
|
||||||
|
assert.Equal(t, "new-secret", m["clientSecret"], "clientSecret should be rotated")
|
||||||
|
assert.Equal(t, "client-id", m["clientID"], "clientID must survive (overlay should leave it alone)")
|
||||||
|
assert.Equal(t, "https://login.microsoftonline.com/tid/v2.0", m["issuer"])
|
||||||
|
assert.Equal(t, "oid", m["userIDKey"], "userIDKey must survive update")
|
||||||
|
assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"], "claimMapping must survive update")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateConnector_DoesNotAddUserIDKeyToExistingConnector(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
p, cleanup := newTestProvider(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Seed a connector directly into storage without userIDKey
|
||||||
|
preFixConfig, err := json.Marshal(map[string]any{
|
||||||
|
"issuer": "https://login.microsoftonline.com/tid/v2.0",
|
||||||
|
"clientID": "client-id",
|
||||||
|
"clientSecret": "old-secret",
|
||||||
|
"redirectURI": "https://example.com/oauth2/callback",
|
||||||
|
"scopes": []string{"openid", "profile", "email"},
|
||||||
|
"claimMapping": map[string]string{"email": "preferred_username"},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, p.storage.CreateConnector(ctx, storage.Connector{
|
||||||
|
ID: "entra-prefix",
|
||||||
|
Type: "oidc",
|
||||||
|
Name: "Entra",
|
||||||
|
Config: preFixConfig,
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Rotate client secret via UpdateConnector.
|
||||||
|
err = p.UpdateConnector(ctx, &ConnectorConfig{
|
||||||
|
ID: "entra-prefix",
|
||||||
|
Type: "entra",
|
||||||
|
ClientSecret: "new-secret",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
conn, err := p.storage.GetConnector(ctx, "entra-prefix")
|
||||||
|
require.NoError(t, err)
|
||||||
|
var m map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(conn.Config, &m))
|
||||||
|
|
||||||
|
assert.Equal(t, "new-secret", m["clientSecret"])
|
||||||
|
_, has := m["userIDKey"]
|
||||||
|
assert.False(t, has, "userIDKey must not be auto-added to a connector that did not have it before")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateConnector_RejectsTypeChange(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
p, cleanup := newTestProvider(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
_, err := p.CreateConnector(ctx, &ConnectorConfig{
|
||||||
|
ID: "entra-test",
|
||||||
|
Name: "Entra",
|
||||||
|
Type: "entra",
|
||||||
|
Issuer: "https://login.microsoftonline.com/tid/v2.0",
|
||||||
|
ClientID: "client-id",
|
||||||
|
ClientSecret: "secret",
|
||||||
|
RedirectURI: "https://example.com/oauth2/callback",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Attempt to switch the connector to okta.
|
||||||
|
err = p.UpdateConnector(ctx, &ConnectorConfig{
|
||||||
|
ID: "entra-test",
|
||||||
|
Type: "okta",
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "connector type change not allowed")
|
||||||
|
|
||||||
|
// stored connector type/config unchanged after the rejected update.
|
||||||
|
conn, err := p.storage.GetConnector(ctx, "entra-test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "oidc", conn.Type)
|
||||||
|
var m map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(conn.Config, &m))
|
||||||
|
assert.Equal(t, "oid", m["userIDKey"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateConnector_AllowsSameTypeUpdate(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
p, cleanup := newTestProvider(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
_, err := p.CreateConnector(ctx, &ConnectorConfig{
|
||||||
|
ID: "entra-test",
|
||||||
|
Name: "Entra",
|
||||||
|
Type: "entra",
|
||||||
|
Issuer: "https://login.microsoftonline.com/old/v2.0",
|
||||||
|
ClientID: "client-id",
|
||||||
|
ClientSecret: "secret",
|
||||||
|
RedirectURI: "https://example.com/oauth2/callback",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = p.UpdateConnector(ctx, &ConnectorConfig{
|
||||||
|
ID: "entra-test",
|
||||||
|
Type: "entra",
|
||||||
|
Issuer: "https://login.microsoftonline.com/new/v2.0",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
conn, err := p.storage.GetConnector(ctx, "entra-test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
var m map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(conn.Config, &m))
|
||||||
|
assert.Equal(t, "https://login.microsoftonline.com/new/v2.0", m["issuer"])
|
||||||
|
}
|
||||||
@@ -11,6 +11,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -82,11 +84,40 @@ type ProxyServiceServer struct {
|
|||||||
// Store for PKCE verifiers
|
// Store for PKCE verifiers
|
||||||
pkceVerifierStore *PKCEVerifierStore
|
pkceVerifierStore *PKCEVerifierStore
|
||||||
|
|
||||||
|
// tokenTTL is the lifetime of one-time tokens generated for proxy
|
||||||
|
// authentication. Defaults to defaultProxyTokenTTL when zero.
|
||||||
|
tokenTTL time.Duration
|
||||||
|
|
||||||
|
// snapshotBatchSize is the number of mappings per gRPC message during
|
||||||
|
// initial snapshot delivery. Configurable via NB_PROXY_SNAPSHOT_BATCH_SIZE.
|
||||||
|
snapshotBatchSize int
|
||||||
|
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
const pkceVerifierTTL = 10 * time.Minute
|
const pkceVerifierTTL = 10 * time.Minute
|
||||||
|
|
||||||
|
const defaultProxyTokenTTL = 5 * time.Minute
|
||||||
|
|
||||||
|
const defaultSnapshotBatchSize = 500
|
||||||
|
|
||||||
|
func snapshotBatchSizeFromEnv() int {
|
||||||
|
if v := os.Getenv("NB_PROXY_SNAPSHOT_BATCH_SIZE"); v != "" {
|
||||||
|
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultSnapshotBatchSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// proxyTokenTTL returns the configured token TTL or the default when unset.
|
||||||
|
func (s *ProxyServiceServer) proxyTokenTTL() time.Duration {
|
||||||
|
if s.tokenTTL > 0 {
|
||||||
|
return s.tokenTTL
|
||||||
|
}
|
||||||
|
return defaultProxyTokenTTL
|
||||||
|
}
|
||||||
|
|
||||||
// proxyConnection represents a connected proxy
|
// proxyConnection represents a connected proxy
|
||||||
type proxyConnection struct {
|
type proxyConnection struct {
|
||||||
proxyID string
|
proxyID string
|
||||||
@@ -110,6 +141,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
|
|||||||
peersManager: peersManager,
|
peersManager: peersManager,
|
||||||
usersManager: usersManager,
|
usersManager: usersManager,
|
||||||
proxyManager: proxyMgr,
|
proxyManager: proxyMgr,
|
||||||
|
snapshotBatchSize: snapshotBatchSizeFromEnv(),
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
}
|
}
|
||||||
go s.cleanupStaleProxies(ctx)
|
go s.cleanupStaleProxies(ctx)
|
||||||
@@ -192,11 +224,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
}
|
}
|
||||||
|
|
||||||
s.connectedProxies.Store(proxyID, conn)
|
|
||||||
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
|
||||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register proxy in database with capabilities
|
// Register proxy in database with capabilities
|
||||||
var caps *proxy.Capabilities
|
var caps *proxy.Capabilities
|
||||||
if c := req.GetCapabilities(); c != nil {
|
if c := req.GetCapabilities(); c != nil {
|
||||||
@@ -209,13 +236,31 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps)
|
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
||||||
s.connectedProxies.CompareAndDelete(proxyID, conn)
|
cancel()
|
||||||
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
|
|
||||||
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
|
|
||||||
}
|
|
||||||
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.connectedProxies.Store(proxyID, conn)
|
||||||
|
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.sendSnapshot(ctx, conn); err != nil {
|
||||||
|
if s.connectedProxies.CompareAndDelete(proxyID, conn) {
|
||||||
|
if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil {
|
||||||
|
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil {
|
||||||
|
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
errChan := make(chan error, 2)
|
||||||
|
go s.sender(conn, errChan)
|
||||||
|
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"proxy_id": proxyID,
|
"proxy_id": proxyID,
|
||||||
"session_id": sessionID,
|
"session_id": sessionID,
|
||||||
@@ -241,13 +286,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
|
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := s.sendSnapshot(ctx, conn); err != nil {
|
|
||||||
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
errChan := make(chan error, 2)
|
|
||||||
go s.sender(conn, errChan)
|
|
||||||
|
|
||||||
go s.heartbeat(connCtx, proxyRecord)
|
go s.heartbeat(connCtx, proxyRecord)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -290,22 +328,27 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Send mappings in batches to reduce per-message gRPC overhead while
|
||||||
|
// staying well within the default 4 MB message size limit.
|
||||||
|
for i := 0; i < len(mappings); i += s.snapshotBatchSize {
|
||||||
|
end := i + s.snapshotBatchSize
|
||||||
|
if end > len(mappings) {
|
||||||
|
end = len(mappings)
|
||||||
|
}
|
||||||
|
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||||
|
Mapping: mappings[i:end],
|
||||||
|
InitialSyncComplete: end == len(mappings),
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("send snapshot batch: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if len(mappings) == 0 {
|
if len(mappings) == 0 {
|
||||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||||
InitialSyncComplete: true,
|
InitialSyncComplete: true,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return fmt.Errorf("send snapshot completion: %w", err)
|
return fmt.Errorf("send snapshot completion: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, m := range mappings {
|
|
||||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
|
||||||
Mapping: []*proto.ProxyMapping{m},
|
|
||||||
InitialSyncComplete: i == len(mappings)-1,
|
|
||||||
}); err != nil {
|
|
||||||
return fmt.Errorf("send proxy mapping: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -323,13 +366,9 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute)
|
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithFields(log.Fields{
|
return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err)
|
||||||
"service": service.Name,
|
|
||||||
"account": service.AccountID,
|
|
||||||
}).WithError(err).Error("failed to generate auth token for snapshot")
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
|
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
|
||||||
@@ -409,13 +448,16 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
|
|||||||
conn := value.(*proxyConnection)
|
conn := value.(*proxyConnection)
|
||||||
resp := s.perProxyMessage(update, conn.proxyID)
|
resp := s.perProxyMessage(update, conn.proxyID)
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
|
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
|
||||||
|
conn.cancel()
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case conn.sendChan <- resp:
|
case conn.sendChan <- resp:
|
||||||
log.Debugf("Sent service update to proxy server %s", conn.proxyID)
|
log.Debugf("Sent service update to proxy server %s", conn.proxyID)
|
||||||
default:
|
default:
|
||||||
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
|
log.Warnf("Send channel full for proxy %s, disconnecting to force resync", conn.proxyID)
|
||||||
|
conn.cancel()
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
@@ -495,13 +537,16 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
|
|||||||
}
|
}
|
||||||
msg := s.perProxyMessage(updateResponse, proxyID)
|
msg := s.perProxyMessage(updateResponse, proxyID)
|
||||||
if msg == nil {
|
if msg == nil {
|
||||||
|
log.WithContext(ctx).Warnf("Token generation failed for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
|
||||||
|
conn.cancel()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case conn.sendChan <- msg:
|
case conn.sendChan <- msg:
|
||||||
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
|
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
|
||||||
default:
|
default:
|
||||||
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
|
log.WithContext(ctx).Warnf("Send channel full for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
|
||||||
|
conn.cancel()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -527,7 +572,8 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo
|
|||||||
// perProxyMessage returns a copy of update with a fresh one-time token for
|
// perProxyMessage returns a copy of update with a fresh one-time token for
|
||||||
// create/update operations. For delete operations the original mapping is
|
// create/update operations. For delete operations the original mapping is
|
||||||
// used unchanged because proxies do not need to authenticate for removal.
|
// used unchanged because proxies do not need to authenticate for removal.
|
||||||
// Returns nil if token generation fails (the proxy should be skipped).
|
// Returns nil if token generation fails; the caller must disconnect the
|
||||||
|
// proxy so it can resync via a fresh snapshot on reconnect.
|
||||||
func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse {
|
func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse {
|
||||||
resp := make([]*proto.ProxyMapping, 0, len(update.Mapping))
|
resp := make([]*proto.ProxyMapping, 0, len(update.Mapping))
|
||||||
for _, mapping := range update.Mapping {
|
for _, mapping := range update.Mapping {
|
||||||
@@ -536,7 +582,7 @@ func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateRespo
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute)
|
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, s.proxyTokenTTL())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
|
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
174
management/internals/shared/grpc/proxy_snapshot_test.go
Normal file
174
management/internals/shared/grpc/proxy_snapshot_test.go
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
|
|
||||||
|
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// recordingStream captures all messages sent via Send so tests can inspect
|
||||||
|
// batching behaviour without a real gRPC transport.
|
||||||
|
type recordingStream struct {
|
||||||
|
grpc.ServerStream
|
||||||
|
messages []*proto.GetMappingUpdateResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *recordingStream) Send(m *proto.GetMappingUpdateResponse) error {
|
||||||
|
s.messages = append(s.messages, m)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *recordingStream) Context() context.Context { return context.Background() }
|
||||||
|
func (s *recordingStream) SetHeader(metadata.MD) error { return nil }
|
||||||
|
func (s *recordingStream) SendHeader(metadata.MD) error { return nil }
|
||||||
|
func (s *recordingStream) SetTrailer(metadata.MD) {}
|
||||||
|
func (s *recordingStream) SendMsg(any) error { return nil }
|
||||||
|
func (s *recordingStream) RecvMsg(any) error { return nil }
|
||||||
|
|
||||||
|
// makeServices creates n enabled services assigned to the given cluster.
|
||||||
|
func makeServices(n int, cluster string) []*rpservice.Service {
|
||||||
|
services := make([]*rpservice.Service, n)
|
||||||
|
for i := range n {
|
||||||
|
services[i] = &rpservice.Service{
|
||||||
|
ID: fmt.Sprintf("svc-%d", i),
|
||||||
|
AccountID: "acct-1",
|
||||||
|
Name: fmt.Sprintf("svc-%d", i),
|
||||||
|
Domain: fmt.Sprintf("svc-%d.example.com", i),
|
||||||
|
ProxyCluster: cluster,
|
||||||
|
Enabled: true,
|
||||||
|
Targets: []*rpservice.Target{
|
||||||
|
{TargetType: rpservice.TargetTypeHost, TargetId: "host-1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return services
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSnapshotTestServer(t *testing.T, batchSize int) *ProxyServiceServer {
|
||||||
|
t.Helper()
|
||||||
|
s := &ProxyServiceServer{
|
||||||
|
tokenStore: NewOneTimeTokenStore(context.Background(), testCacheStore(t)),
|
||||||
|
snapshotBatchSize: batchSize,
|
||||||
|
}
|
||||||
|
s.SetProxyController(newTestProxyController())
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendSnapshot_BatchesMappings(t *testing.T) {
|
||||||
|
const cluster = "cluster.example.com"
|
||||||
|
const batchSize = 3
|
||||||
|
const totalServices = 7 // 3 + 3 + 1
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
mgr := rpservice.NewMockManager(ctrl)
|
||||||
|
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||||
|
|
||||||
|
s := newSnapshotTestServer(t, batchSize)
|
||||||
|
s.serviceManager = mgr
|
||||||
|
|
||||||
|
stream := &recordingStream{}
|
||||||
|
conn := &proxyConnection{
|
||||||
|
proxyID: "proxy-a",
|
||||||
|
address: cluster,
|
||||||
|
stream: stream,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.sendSnapshot(context.Background(), conn)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Expect ceil(7/3) = 3 messages
|
||||||
|
require.Len(t, stream.messages, 3, "should send ceil(totalServices/batchSize) messages")
|
||||||
|
|
||||||
|
assert.Len(t, stream.messages[0].Mapping, 3)
|
||||||
|
assert.False(t, stream.messages[0].InitialSyncComplete, "first batch should not be sync-complete")
|
||||||
|
|
||||||
|
assert.Len(t, stream.messages[1].Mapping, 3)
|
||||||
|
assert.False(t, stream.messages[1].InitialSyncComplete, "middle batch should not be sync-complete")
|
||||||
|
|
||||||
|
assert.Len(t, stream.messages[2].Mapping, 1)
|
||||||
|
assert.True(t, stream.messages[2].InitialSyncComplete, "last batch must be sync-complete")
|
||||||
|
|
||||||
|
// Verify all service IDs are present exactly once
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
for _, msg := range stream.messages {
|
||||||
|
for _, m := range msg.Mapping {
|
||||||
|
assert.False(t, seen[m.Id], "duplicate service ID %s", m.Id)
|
||||||
|
seen[m.Id] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Len(t, seen, totalServices)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendSnapshot_ExactBatchMultiple(t *testing.T) {
|
||||||
|
const cluster = "cluster.example.com"
|
||||||
|
const batchSize = 3
|
||||||
|
const totalServices = 6 // exactly 2 batches
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
mgr := rpservice.NewMockManager(ctrl)
|
||||||
|
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||||
|
|
||||||
|
s := newSnapshotTestServer(t, batchSize)
|
||||||
|
s.serviceManager = mgr
|
||||||
|
|
||||||
|
stream := &recordingStream{}
|
||||||
|
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
|
||||||
|
|
||||||
|
require.NoError(t, s.sendSnapshot(context.Background(), conn))
|
||||||
|
require.Len(t, stream.messages, 2)
|
||||||
|
|
||||||
|
assert.Len(t, stream.messages[0].Mapping, 3)
|
||||||
|
assert.False(t, stream.messages[0].InitialSyncComplete)
|
||||||
|
|
||||||
|
assert.Len(t, stream.messages[1].Mapping, 3)
|
||||||
|
assert.True(t, stream.messages[1].InitialSyncComplete)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendSnapshot_SingleBatch(t *testing.T) {
|
||||||
|
const cluster = "cluster.example.com"
|
||||||
|
const batchSize = 100
|
||||||
|
const totalServices = 5
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
mgr := rpservice.NewMockManager(ctrl)
|
||||||
|
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||||
|
|
||||||
|
s := newSnapshotTestServer(t, batchSize)
|
||||||
|
s.serviceManager = mgr
|
||||||
|
|
||||||
|
stream := &recordingStream{}
|
||||||
|
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
|
||||||
|
|
||||||
|
require.NoError(t, s.sendSnapshot(context.Background(), conn))
|
||||||
|
require.Len(t, stream.messages, 1, "all mappings should fit in one batch")
|
||||||
|
assert.Len(t, stream.messages[0].Mapping, totalServices)
|
||||||
|
assert.True(t, stream.messages[0].InitialSyncComplete)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendSnapshot_EmptySnapshot(t *testing.T) {
|
||||||
|
const cluster = "cluster.example.com"
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
mgr := rpservice.NewMockManager(ctrl)
|
||||||
|
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil)
|
||||||
|
|
||||||
|
s := newSnapshotTestServer(t, 500)
|
||||||
|
s.serviceManager = mgr
|
||||||
|
|
||||||
|
stream := &recordingStream{}
|
||||||
|
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
|
||||||
|
|
||||||
|
require.NoError(t, s.sendSnapshot(context.Background(), conn))
|
||||||
|
require.Len(t, stream.messages, 1, "empty snapshot must still send sync-complete")
|
||||||
|
assert.Empty(t, stream.messages[0].Mapping)
|
||||||
|
assert.True(t, stream.messages[0].InitialSyncComplete)
|
||||||
|
}
|
||||||
@@ -85,11 +85,14 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
|
|||||||
// registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities.
|
// registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities.
|
||||||
func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse {
|
func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse {
|
||||||
ch := make(chan *proto.GetMappingUpdateResponse, 10)
|
ch := make(chan *proto.GetMappingUpdateResponse, 10)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
conn := &proxyConnection{
|
conn := &proxyConnection{
|
||||||
proxyID: proxyID,
|
proxyID: proxyID,
|
||||||
address: clusterAddr,
|
address: clusterAddr,
|
||||||
capabilities: caps,
|
capabilities: caps,
|
||||||
sendChan: ch,
|
sendChan: ch,
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
}
|
}
|
||||||
s.connectedProxies.Store(proxyID, conn)
|
s.connectedProxies.Store(proxyID, conn)
|
||||||
|
|
||||||
|
|||||||
@@ -1040,6 +1040,13 @@ func (am *DefaultAccountManager) lookupCache(ctx context.Context, accountUsers m
|
|||||||
|
|
||||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||||
if am.isCacheFresh(ctx, accountUsers, data) {
|
if am.isCacheFresh(ctx, accountUsers, data) {
|
||||||
|
// Catch the silent vacuous-fresh case: empty accountUsers map + empty cache data
|
||||||
|
// → isCacheFresh returns true without iterating, skipping refreshCache,
|
||||||
|
// returning empty data. This matters when the account is mostly integration
|
||||||
|
// (SCIM) users and the InternalCache has been flushed.
|
||||||
|
if len(accountUsers) == 0 && len(data) == 0 {
|
||||||
|
log.WithContext(ctx).Warnf("lookupCache VACUOUS FRESH: accountUsers map is empty AND cache is empty for account %s — returning empty data without triggering loadAccount", accountID)
|
||||||
|
}
|
||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,11 +7,8 @@ import (
|
|||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
@@ -45,11 +42,6 @@ func newGeolocationsHandlerHandler(accountManager account.Manager, geolocationMa
|
|||||||
|
|
||||||
// getAllCountries retrieves a list of all countries
|
// getAllCountries retrieves a list of all countries
|
||||||
func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) {
|
func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) {
|
||||||
if err := l.authenticateUser(r); err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if l.geolocationManager == nil {
|
if l.geolocationManager == nil {
|
||||||
// TODO: update error message to include geo db self hosted doc link when ready
|
// TODO: update error message to include geo db self hosted doc link when ready
|
||||||
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
|
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
|
||||||
@@ -71,11 +63,6 @@ func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Req
|
|||||||
|
|
||||||
// getCitiesByCountry retrieves a list of cities based on the given country code
|
// getCitiesByCountry retrieves a list of cities based on the given country code
|
||||||
func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) {
|
func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) {
|
||||||
if err := l.authenticateUser(r); err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
countryCode := vars["country"]
|
countryCode := vars["country"]
|
||||||
if !countryCodeRegex.MatchString(countryCode) {
|
if !countryCodeRegex.MatchString(countryCode) {
|
||||||
@@ -102,27 +89,6 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.
|
|||||||
util.WriteJSONObject(r.Context(), w, cities)
|
util.WriteJSONObject(r.Context(), w, cities)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *geolocationsHandler) authenticateUser(r *http.Request) error {
|
|
||||||
ctx := r.Context()
|
|
||||||
|
|
||||||
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
|
||||||
|
|
||||||
allowed, err := l.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
|
|
||||||
if err != nil {
|
|
||||||
return status.NewPermissionValidationError(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !allowed {
|
|
||||||
return status.NewPermissionDeniedError()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func toCountryResponse(country geolocation.Country) api.Country {
|
func toCountryResponse(country geolocation.Country) api.Country {
|
||||||
return api.Country{
|
return api.Country{
|
||||||
CountryName: country.CountryName,
|
CountryName: country.CountryName,
|
||||||
|
|||||||
@@ -144,8 +144,11 @@ func TestValidateInviteToken_ModifiedToken(t *testing.T) {
|
|||||||
_, plainToken, err := GenerateInviteToken()
|
_, plainToken, err := GenerateInviteToken()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Modify one character in the secret part
|
replacement := "X"
|
||||||
modifiedToken := plainToken[:5] + "X" + plainToken[6:]
|
if plainToken[5] == 'X' {
|
||||||
|
replacement = "Y"
|
||||||
|
}
|
||||||
|
modifiedToken := plainToken[:5] + replacement + plainToken[6:]
|
||||||
err = ValidateInviteToken(modifiedToken)
|
err = ValidateInviteToken(modifiedToken)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1061,15 +1061,28 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
|
|||||||
queriedUsers = append(queriedUsers, usersFromIntegration...)
|
queriedUsers = append(queriedUsers, usersFromIntegration...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
idpManagerNil := isNil(am.idpManager)
|
||||||
|
idpManagerEmbedded := !idpManagerNil && IsEmbeddedIdp(am.idpManager)
|
||||||
|
|
||||||
userInfosMap := make(map[string]*types.UserInfo)
|
userInfosMap := make(map[string]*types.UserInfo)
|
||||||
|
|
||||||
// in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo
|
// in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo
|
||||||
if len(queriedUsers) == 0 {
|
if len(queriedUsers) == 0 {
|
||||||
|
var earlyReturnEmpty int
|
||||||
|
var earlyReturnEmptySamples []string
|
||||||
for _, accountUser := range accountUsers {
|
for _, accountUser := range accountUsers {
|
||||||
info, err := accountUser.ToUserInfo(nil)
|
info, err := accountUser.ToUserInfo(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if !accountUser.IsServiceUser && (info.Email == "" || info.Name == "") {
|
||||||
|
earlyReturnEmpty++
|
||||||
|
if len(earlyReturnEmptySamples) < 50 {
|
||||||
|
earlyReturnEmptySamples = append(earlyReturnEmptySamples,
|
||||||
|
fmt.Sprintf("%s(issued=%s,db.email=%q,db.name=%q)",
|
||||||
|
accountUser.Id, accountUser.Issued, accountUser.Email, accountUser.Name))
|
||||||
|
}
|
||||||
|
}
|
||||||
// Try to decode Dex user ID to extract the IdP ID (connector ID)
|
// Try to decode Dex user ID to extract the IdP ID (connector ID)
|
||||||
if _, connectorID, decodeErr := dex.DecodeDexUserID(accountUser.Id); decodeErr == nil && connectorID != "" {
|
if _, connectorID, decodeErr := dex.DecodeDexUserID(accountUser.Id); decodeErr == nil && connectorID != "" {
|
||||||
info.IdPID = connectorID
|
info.IdPID = connectorID
|
||||||
@@ -1077,17 +1090,61 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
|
|||||||
userInfosMap[accountUser.Id] = info
|
userInfosMap[accountUser.Id] = info
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Warnf("BuildUserInfosForAccount EARLY RETURN: queriedUsers empty, returning %d users with DB-only data (idpManagerNil=%v, idpManagerEmbedded=%v). %d non-service users have empty email/name in DB. Samples: %v",
|
||||||
|
len(accountUsers), idpManagerNil, idpManagerEmbedded, earlyReturnEmpty, earlyReturnEmptySamples)
|
||||||
|
|
||||||
|
// Same canonical-truth final scan, also on the early-return path
|
||||||
|
var finalEmptyEmail, finalEmptyName int
|
||||||
|
var finalEmptySamples []string
|
||||||
|
for id, info := range userInfosMap {
|
||||||
|
if info.IsServiceUser {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if info.Email == "" {
|
||||||
|
finalEmptyEmail++
|
||||||
|
}
|
||||||
|
if info.Name == "" {
|
||||||
|
finalEmptyName++
|
||||||
|
}
|
||||||
|
if (info.Email == "" || info.Name == "") && len(finalEmptySamples) < 200 {
|
||||||
|
finalEmptySamples = append(finalEmptySamples,
|
||||||
|
fmt.Sprintf("%s(email=%q,name=%q,issued=%s)", id, info.Email, info.Name, info.Issued))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if finalEmptyEmail > 0 || finalEmptyName > 0 {
|
||||||
|
log.WithContext(ctx).Warnf("BuildUserInfosForAccount FINAL (early-return path): returning %d UserInfo entries — %d with empty email, %d with empty name. Samples: %v",
|
||||||
|
len(userInfosMap), finalEmptyEmail, finalEmptyName, finalEmptySamples)
|
||||||
|
}
|
||||||
|
|
||||||
return userInfosMap, nil
|
return userInfosMap, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var cacheHitEmpty, fallbackMiss int
|
||||||
|
var cacheHitEmptySamples, fallbackMissSamples []string
|
||||||
for _, localUser := range accountUsers {
|
for _, localUser := range accountUsers {
|
||||||
var info *types.UserInfo
|
var info *types.UserInfo
|
||||||
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
|
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
|
||||||
|
if !localUser.IsServiceUser && (queriedUser.Email == "" || queriedUser.Name == "") {
|
||||||
|
cacheHitEmpty++
|
||||||
|
if len(cacheHitEmptySamples) < 50 {
|
||||||
|
cacheHitEmptySamples = append(cacheHitEmptySamples,
|
||||||
|
fmt.Sprintf("%s(cache.email=%q,cache.name=%q,db.email=%q,db.name=%q)",
|
||||||
|
localUser.Id, queriedUser.Email, queriedUser.Name, localUser.Email, localUser.Name))
|
||||||
|
}
|
||||||
|
}
|
||||||
info, err = localUser.ToUserInfo(queriedUser)
|
info, err = localUser.ToUserInfo(queriedUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
if !localUser.IsServiceUser {
|
||||||
|
fallbackMiss++
|
||||||
|
if len(fallbackMissSamples) < 50 {
|
||||||
|
fallbackMissSamples = append(fallbackMissSamples,
|
||||||
|
fmt.Sprintf("%s(issued=%s,db.email=%q,db.name=%q)",
|
||||||
|
localUser.Id, localUser.Issued, localUser.Email, localUser.Name))
|
||||||
|
}
|
||||||
|
}
|
||||||
name := ""
|
name := ""
|
||||||
if localUser.IsServiceUser {
|
if localUser.IsServiceUser {
|
||||||
name = localUser.ServiceUserName
|
name = localUser.ServiceUserName
|
||||||
@@ -1111,6 +1168,40 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
|
|||||||
userInfosMap[info.ID] = info
|
userInfosMap[info.ID] = info
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cacheHitEmpty > 0 {
|
||||||
|
log.WithContext(ctx).Warnf("BuildUserInfosForAccount: %d users found in cache with empty email or name (cache pollution). Samples: %v",
|
||||||
|
cacheHitEmpty, cacheHitEmptySamples)
|
||||||
|
}
|
||||||
|
if fallbackMiss > 0 {
|
||||||
|
log.WithContext(ctx).Warnf("BuildUserInfosForAccount: %d non-service users missed both caches (will get empty Name in API response from fallback). Samples: %v",
|
||||||
|
fallbackMiss, fallbackMissSamples)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Canonical-truth log: scan what we are actually about to return to the handler.
|
||||||
|
// This catches empties from any code path (early-return, cache-hit-empty, fallback-miss,
|
||||||
|
// or anything we haven't identified yet).
|
||||||
|
var finalEmptyEmail, finalEmptyName int
|
||||||
|
var finalEmptySamples []string
|
||||||
|
for id, info := range userInfosMap {
|
||||||
|
if info.IsServiceUser {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if info.Email == "" {
|
||||||
|
finalEmptyEmail++
|
||||||
|
}
|
||||||
|
if info.Name == "" {
|
||||||
|
finalEmptyName++
|
||||||
|
}
|
||||||
|
if (info.Email == "" || info.Name == "") && len(finalEmptySamples) < 200 {
|
||||||
|
finalEmptySamples = append(finalEmptySamples,
|
||||||
|
fmt.Sprintf("%s(email=%q,name=%q,issued=%s)", id, info.Email, info.Name, info.Issued))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if finalEmptyEmail > 0 || finalEmptyName > 0 {
|
||||||
|
log.WithContext(ctx).Warnf("BuildUserInfosForAccount FINAL: returning %d UserInfo entries — %d with empty email, %d with empty name. Samples: %v",
|
||||||
|
len(userInfosMap), finalEmptyEmail, finalEmptyName, finalEmptySamples)
|
||||||
|
}
|
||||||
|
|
||||||
return userInfosMap, nil
|
return userInfosMap, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -364,14 +364,16 @@ func TestIntegration_ProxyConnection_HappyPath(t *testing.T) {
|
|||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Receive all mappings from the snapshot - server sends each mapping individually
|
|
||||||
mappingsByID := make(map[string]*proto.ProxyMapping)
|
mappingsByID := make(map[string]*proto.ProxyMapping)
|
||||||
for i := 0; i < 2; i++ {
|
for {
|
||||||
msg, err := stream.Recv()
|
msg, err := stream.Recv()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
for _, m := range msg.GetMapping() {
|
for _, m := range msg.GetMapping() {
|
||||||
mappingsByID[m.GetId()] = m
|
mappingsByID[m.GetId()] = m
|
||||||
}
|
}
|
||||||
|
if msg.GetInitialSyncComplete() {
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should receive 2 mappings total
|
// Should receive 2 mappings total
|
||||||
@@ -411,12 +413,14 @@ func TestIntegration_ProxyConnection_SendsClusterAddress(t *testing.T) {
|
|||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Receive all mappings - server sends each mapping individually
|
|
||||||
mappings := make([]*proto.ProxyMapping, 0)
|
mappings := make([]*proto.ProxyMapping, 0)
|
||||||
for i := 0; i < 2; i++ {
|
for {
|
||||||
msg, err := stream.Recv()
|
msg, err := stream.Recv()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
mappings = append(mappings, msg.GetMapping()...)
|
mappings = append(mappings, msg.GetMapping()...)
|
||||||
|
if msg.GetInitialSyncComplete() {
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should receive the 2 mappings matching the cluster
|
// Should receive the 2 mappings matching the cluster
|
||||||
@@ -440,13 +444,15 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
|
|||||||
clusterAddress := "test.proxy.io"
|
clusterAddress := "test.proxy.io"
|
||||||
proxyID := "test-proxy-reconnect"
|
proxyID := "test-proxy-reconnect"
|
||||||
|
|
||||||
// Helper to receive all mappings from a stream
|
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping {
|
||||||
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient, count int) []*proto.ProxyMapping {
|
|
||||||
var mappings []*proto.ProxyMapping
|
var mappings []*proto.ProxyMapping
|
||||||
for i := 0; i < count; i++ {
|
for {
|
||||||
msg, err := stream.Recv()
|
msg, err := stream.Recv()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
mappings = append(mappings, msg.GetMapping()...)
|
mappings = append(mappings, msg.GetMapping()...)
|
||||||
|
if msg.GetInitialSyncComplete() {
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return mappings
|
return mappings
|
||||||
}
|
}
|
||||||
@@ -460,7 +466,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
|
|||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
firstMappings := receiveMappings(stream1, 2)
|
firstMappings := receiveMappings(stream1)
|
||||||
cancel1()
|
cancel1()
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
@@ -476,7 +482,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
|
|||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
secondMappings := receiveMappings(stream2, 2)
|
secondMappings := receiveMappings(stream2)
|
||||||
|
|
||||||
// Should receive the same mappings
|
// Should receive the same mappings
|
||||||
assert.Equal(t, len(firstMappings), len(secondMappings),
|
assert.Equal(t, len(firstMappings), len(secondMappings),
|
||||||
@@ -542,12 +548,14 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper to receive and apply all mappings
|
|
||||||
receiveAndApply := func(stream proto.ProxyService_GetMappingUpdateClient) {
|
receiveAndApply := func(stream proto.ProxyService_GetMappingUpdateClient) {
|
||||||
for i := 0; i < 2; i++ {
|
for {
|
||||||
msg, err := stream.Recv()
|
msg, err := stream.Recv()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
applyMappings(msg.GetMapping())
|
applyMappings(msg.GetMapping())
|
||||||
|
if msg.GetInitialSyncComplete() {
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -636,12 +644,14 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T)
|
|||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Receive all mappings - server sends each mapping individually
|
|
||||||
count := 0
|
count := 0
|
||||||
for i := 0; i < 2; i++ {
|
for {
|
||||||
msg, err := stream.Recv()
|
msg, err := stream.Recv()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
count += len(msg.GetMapping())
|
count += len(msg.GetMapping())
|
||||||
|
if msg.GetInitialSyncComplete() {
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
@@ -681,9 +691,12 @@ func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T)
|
|||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for i := 0; i < 2; i++ {
|
for {
|
||||||
_, err := stream1.Recv()
|
msg, err := stream1.Recv()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
if msg.GetInitialSyncComplete() {
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
|
require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
|
||||||
@@ -699,9 +712,12 @@ func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T)
|
|||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for i := 0; i < 2; i++ {
|
for {
|
||||||
_, err := stream2.Recv()
|
msg, err := stream2.Recv()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
if msg.GetInitialSyncComplete() {
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cancel1()
|
cancel1()
|
||||||
|
|||||||
@@ -943,6 +943,8 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
|
|||||||
operation := func() error {
|
operation := func() error {
|
||||||
s.Logger.Debug("connecting to management mapping stream")
|
s.Logger.Debug("connecting to management mapping stream")
|
||||||
|
|
||||||
|
initialSyncDone = false
|
||||||
|
|
||||||
if s.healthChecker != nil {
|
if s.healthChecker != nil {
|
||||||
s.healthChecker.SetManagementConnected(false)
|
s.healthChecker.SetManagementConnected(false)
|
||||||
}
|
}
|
||||||
@@ -1000,6 +1002,11 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
|||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var snapshotIDs map[types.ServiceID]struct{}
|
||||||
|
if !*initialSyncDone {
|
||||||
|
snapshotIDs = make(map[types.ServiceID]struct{})
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// Check for context completion to gracefully shutdown.
|
// Check for context completion to gracefully shutdown.
|
||||||
select {
|
select {
|
||||||
@@ -1020,17 +1027,45 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
|||||||
s.processMappings(ctx, msg.GetMapping())
|
s.processMappings(ctx, msg.GetMapping())
|
||||||
s.Logger.Debug("Processing mapping update completed")
|
s.Logger.Debug("Processing mapping update completed")
|
||||||
|
|
||||||
if !*initialSyncDone && msg.GetInitialSyncComplete() {
|
if !*initialSyncDone {
|
||||||
if s.healthChecker != nil {
|
for _, m := range msg.GetMapping() {
|
||||||
s.healthChecker.SetInitialSyncComplete()
|
snapshotIDs[types.ServiceID(m.GetId())] = struct{}{}
|
||||||
|
}
|
||||||
|
if msg.GetInitialSyncComplete() {
|
||||||
|
s.reconcileSnapshot(ctx, snapshotIDs)
|
||||||
|
snapshotIDs = nil
|
||||||
|
if s.healthChecker != nil {
|
||||||
|
s.healthChecker.SetInitialSyncComplete()
|
||||||
|
}
|
||||||
|
*initialSyncDone = true
|
||||||
|
s.Logger.Info("Initial mapping sync complete")
|
||||||
}
|
}
|
||||||
*initialSyncDone = true
|
|
||||||
s.Logger.Info("Initial mapping sync complete")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// reconcileSnapshot removes local mappings that are absent from the snapshot.
|
||||||
|
// This ensures services deleted while the proxy was disconnected get cleaned up.
|
||||||
|
func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) {
|
||||||
|
s.portMu.RLock()
|
||||||
|
var stale []*proto.ProxyMapping
|
||||||
|
for svcID, mapping := range s.lastMappings {
|
||||||
|
if _, ok := snapshotIDs[svcID]; !ok {
|
||||||
|
stale = append(stale, mapping)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.portMu.RUnlock()
|
||||||
|
|
||||||
|
for _, mapping := range stale {
|
||||||
|
s.Logger.WithFields(log.Fields{
|
||||||
|
"service_id": mapping.GetId(),
|
||||||
|
"domain": mapping.GetDomain(),
|
||||||
|
}).Info("Removing stale mapping absent from snapshot")
|
||||||
|
s.removeMapping(ctx, mapping)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
|
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
|
||||||
for _, mapping := range mappings {
|
for _, mapping := range mappings {
|
||||||
s.Logger.WithFields(log.Fields{
|
s.Logger.WithFields(log.Fields{
|
||||||
|
|||||||
227
proxy/snapshot_reconcile_test.go
Normal file
227
proxy/snapshot_reconcile_test.go
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/proxy/internal/health"
|
||||||
|
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// collectStaleIDs mirrors the stale-detection logic in reconcileSnapshot
|
||||||
|
// so we can verify it without triggering removeMapping (which requires full
|
||||||
|
// server wiring). This keeps the test focused on the detection algorithm.
|
||||||
|
func collectStaleIDs(lastMappings map[types.ServiceID]*proto.ProxyMapping, snapshotIDs map[types.ServiceID]struct{}) []types.ServiceID {
|
||||||
|
var stale []types.ServiceID
|
||||||
|
for svcID := range lastMappings {
|
||||||
|
if _, ok := snapshotIDs[svcID]; !ok {
|
||||||
|
stale = append(stale, svcID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return stale
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStaleDetection_PartialOverlap verifies that only services absent from
|
||||||
|
// the snapshot are flagged as stale.
|
||||||
|
func TestStaleDetection_PartialOverlap(t *testing.T) {
|
||||||
|
local := map[types.ServiceID]*proto.ProxyMapping{
|
||||||
|
"svc-1": {Id: "svc-1"},
|
||||||
|
"svc-2": {Id: "svc-2"},
|
||||||
|
"svc-stale-a": {Id: "svc-stale-a"},
|
||||||
|
"svc-stale-b": {Id: "svc-stale-b"},
|
||||||
|
}
|
||||||
|
snapshot := map[types.ServiceID]struct{}{
|
||||||
|
"svc-1": {},
|
||||||
|
"svc-2": {},
|
||||||
|
"svc-3": {}, // new service, not in local
|
||||||
|
}
|
||||||
|
|
||||||
|
stale := collectStaleIDs(local, snapshot)
|
||||||
|
assert.Len(t, stale, 2)
|
||||||
|
staleSet := make(map[types.ServiceID]struct{})
|
||||||
|
for _, id := range stale {
|
||||||
|
staleSet[id] = struct{}{}
|
||||||
|
}
|
||||||
|
assert.Contains(t, staleSet, types.ServiceID("svc-stale-a"))
|
||||||
|
assert.Contains(t, staleSet, types.ServiceID("svc-stale-b"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStaleDetection_AllStale verifies an empty snapshot flags everything.
|
||||||
|
func TestStaleDetection_AllStale(t *testing.T) {
|
||||||
|
local := map[types.ServiceID]*proto.ProxyMapping{
|
||||||
|
"svc-1": {Id: "svc-1"},
|
||||||
|
"svc-2": {Id: "svc-2"},
|
||||||
|
}
|
||||||
|
stale := collectStaleIDs(local, map[types.ServiceID]struct{}{})
|
||||||
|
assert.Len(t, stale, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStaleDetection_NoneStale verifies full overlap produces no stale entries.
|
||||||
|
func TestStaleDetection_NoneStale(t *testing.T) {
|
||||||
|
local := map[types.ServiceID]*proto.ProxyMapping{
|
||||||
|
"svc-1": {Id: "svc-1"},
|
||||||
|
"svc-2": {Id: "svc-2"},
|
||||||
|
}
|
||||||
|
snapshot := map[types.ServiceID]struct{}{
|
||||||
|
"svc-1": {},
|
||||||
|
"svc-2": {},
|
||||||
|
}
|
||||||
|
stale := collectStaleIDs(local, snapshot)
|
||||||
|
assert.Empty(t, stale)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStaleDetection_EmptyLocal verifies no stale entries when local is empty.
|
||||||
|
func TestStaleDetection_EmptyLocal(t *testing.T) {
|
||||||
|
stale := collectStaleIDs(
|
||||||
|
map[types.ServiceID]*proto.ProxyMapping{},
|
||||||
|
map[types.ServiceID]struct{}{"svc-1": {}},
|
||||||
|
)
|
||||||
|
assert.Empty(t, stale)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestReconcileSnapshot_NoStale verifies reconciliation is a no-op when all
|
||||||
|
// local mappings are present in the snapshot (removeMapping is never called).
|
||||||
|
func TestReconcileSnapshot_NoStale(t *testing.T) {
|
||||||
|
s := &Server{
|
||||||
|
Logger: log.StandardLogger(),
|
||||||
|
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||||
|
}
|
||||||
|
s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1"}
|
||||||
|
s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2"}
|
||||||
|
|
||||||
|
snapshotIDs := map[types.ServiceID]struct{}{
|
||||||
|
"svc-1": {},
|
||||||
|
"svc-2": {},
|
||||||
|
}
|
||||||
|
// This should not panic — no stale entries means removeMapping is never called.
|
||||||
|
s.reconcileSnapshot(context.Background(), snapshotIDs)
|
||||||
|
|
||||||
|
assert.Len(t, s.lastMappings, 2, "no mappings should be removed when all are in snapshot")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestReconcileSnapshot_EmptyLocal verifies reconciliation is a no-op with
|
||||||
|
// no local mappings.
|
||||||
|
func TestReconcileSnapshot_EmptyLocal(t *testing.T) {
|
||||||
|
s := &Server{
|
||||||
|
Logger: log.StandardLogger(),
|
||||||
|
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||||
|
}
|
||||||
|
s.reconcileSnapshot(context.Background(), map[types.ServiceID]struct{}{"svc-1": {}})
|
||||||
|
assert.Empty(t, s.lastMappings)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- handleMappingStream tests for batched snapshot ID accumulation ---
|
||||||
|
|
||||||
|
// TestHandleMappingStream_BatchedSnapshotSyncComplete verifies that sync is
|
||||||
|
// marked done only after the final InitialSyncComplete message, even when
|
||||||
|
// the snapshot arrives in multiple batches.
|
||||||
|
func TestHandleMappingStream_BatchedSnapshotSyncComplete(t *testing.T) {
|
||||||
|
checker := health.NewChecker(nil, nil)
|
||||||
|
s := &Server{
|
||||||
|
Logger: log.StandardLogger(),
|
||||||
|
healthChecker: checker,
|
||||||
|
routerReady: closedChan(),
|
||||||
|
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||||
|
}
|
||||||
|
|
||||||
|
stream := &mockMappingStream{
|
||||||
|
messages: []*proto.GetMappingUpdateResponse{
|
||||||
|
{}, // batch 1: no sync-complete
|
||||||
|
{}, // batch 2: no sync-complete
|
||||||
|
{InitialSyncComplete: true}, // batch 3: sync done
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
syncDone := false
|
||||||
|
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, syncDone, "sync should be marked done after final batch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleMappingStream_PostSyncDoesNotReconcile verifies that messages
|
||||||
|
// arriving after InitialSyncComplete do not trigger a second reconciliation.
|
||||||
|
func TestHandleMappingStream_PostSyncDoesNotReconcile(t *testing.T) {
|
||||||
|
s := &Server{
|
||||||
|
Logger: log.StandardLogger(),
|
||||||
|
routerReady: closedChan(),
|
||||||
|
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate state left over from a previous sync.
|
||||||
|
s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1", AccountId: "acct-1"}
|
||||||
|
s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2", AccountId: "acct-1"}
|
||||||
|
|
||||||
|
stream := &mockMappingStream{
|
||||||
|
messages: []*proto.GetMappingUpdateResponse{
|
||||||
|
{}, // post-sync empty message — must not reconcile
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
syncDone := true // sync already completed in a previous stream
|
||||||
|
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, s.lastMappings, 2,
|
||||||
|
"post-sync messages must not trigger reconciliation — all entries should survive")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleMappingStream_ImmediateEOF_NoReconciliation verifies that if the
|
||||||
|
// stream closes before sync completes, no reconciliation occurs.
|
||||||
|
func TestHandleMappingStream_ImmediateEOF_NoReconciliation(t *testing.T) {
|
||||||
|
s := &Server{
|
||||||
|
Logger: log.StandardLogger(),
|
||||||
|
routerReady: closedChan(),
|
||||||
|
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||||
|
}
|
||||||
|
|
||||||
|
s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}
|
||||||
|
|
||||||
|
stream := &mockMappingStream{} // no messages → immediate EOF
|
||||||
|
|
||||||
|
syncDone := false
|
||||||
|
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.False(t, syncDone, "sync should not be marked done on immediate EOF")
|
||||||
|
|
||||||
|
_, hasStale := s.lastMappings["svc-stale"]
|
||||||
|
assert.True(t, hasStale, "stale mapping should remain when sync never completed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockErrRecvStream returns an error on the second Recv to verify
|
||||||
|
// handleMappingStream returns without completing sync.
|
||||||
|
type mockErrRecvStream struct {
|
||||||
|
mockMappingStream
|
||||||
|
calls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockErrRecvStream) Recv() (*proto.GetMappingUpdateResponse, error) {
|
||||||
|
m.calls++
|
||||||
|
if m.calls == 1 {
|
||||||
|
return &proto.GetMappingUpdateResponse{}, nil
|
||||||
|
}
|
||||||
|
return nil, io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleMappingStream_ErrorMidSync_NoReconciliation(t *testing.T) {
|
||||||
|
s := &Server{
|
||||||
|
Logger: log.StandardLogger(),
|
||||||
|
routerReady: closedChan(),
|
||||||
|
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||||
|
}
|
||||||
|
|
||||||
|
s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}
|
||||||
|
|
||||||
|
syncDone := false
|
||||||
|
err := s.handleMappingStream(context.Background(), &mockErrRecvStream{}, &syncDone)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.False(t, syncDone)
|
||||||
|
|
||||||
|
_, hasStale := s.lastMappings["svc-stale"]
|
||||||
|
assert.True(t, hasStale, "stale mapping should remain when sync was interrupted by error")
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user