diff --git a/client/internal/dns/dnsfw/config.go b/client/internal/dns/dnsfw/config.go index 372a34686..f11913cdd 100644 --- a/client/internal/dns/dnsfw/config.go +++ b/client/internal/dns/dnsfw/config.go @@ -25,12 +25,6 @@ const ( // processes: 53 (plain DNS) and 853 (DNS-over-TLS). var defaultBlockedPorts = []uint16{53, 853} -// strictMode reports whether strict mode is enabled via env. -func strictMode() bool { - v, _ := strconv.ParseBool(os.Getenv(EnvStrict)) - return v -} - // blockedPorts returns the effective port list, honoring env overrides. // A nil return means the firewall should not be installed. func blockedPorts() []uint16 { diff --git a/client/internal/dns/dnsfw/dnsfw_windows.go b/client/internal/dns/dnsfw/dnsfw_windows.go index e1ae412a7..5ee95fbc1 100644 --- a/client/internal/dns/dnsfw/dnsfw_windows.go +++ b/client/internal/dns/dnsfw/dnsfw_windows.go @@ -6,6 +6,7 @@ import ( "fmt" "net/netip" "os" + "strconv" "sync" "unsafe" @@ -113,6 +114,12 @@ func New() Manager { return &windowsManager{} } +// strictMode reports whether strict mode is enabled via env. +func strictMode() bool { + v, _ := strconv.ParseBool(os.Getenv(EnvStrict)) + return v +} + // luidFromGUID converts a Windows interface GUID string to its LUID. func luidFromGUID(ifaceGUID string) (luid uint64, err error) { defer func() { diff --git a/client/internal/dns/dnsfw/dnsfw_windows_test.go b/client/internal/dns/dnsfw/dnsfw_windows_test.go new file mode 100644 index 000000000..0f7c623bf --- /dev/null +++ b/client/internal/dns/dnsfw/dnsfw_windows_test.go @@ -0,0 +1,72 @@ +//go:build windows + +package dnsfw + +import ( + "net/netip" + "os" + "testing" +) + +func TestStrictMode(t *testing.T) { + tests := []struct { + name string + val string + set bool + want bool + }{ + {name: "unset", want: false}, + {name: "true", val: "true", set: true, want: true}, + {name: "1", val: "1", set: true, want: true}, + {name: "false", val: "false", set: true, want: false}, + {name: "invalid is false", val: "garbage", set: true, want: false}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Setenv(EnvStrict, tc.val) + if !tc.set { + os.Unsetenv(EnvStrict) + } + if got := strictMode(); got != tc.want { + t.Fatalf("strictMode() = %v, want %v", got, tc.want) + } + }) + } +} + +func TestWindowsManagerDisableIdempotent(t *testing.T) { + m := &windowsManager{} + if err := m.Disable(); err != nil { + t.Fatalf("first Disable on fresh manager: %v", err) + } + if err := m.Disable(); err != nil { + t.Fatalf("second Disable on fresh manager: %v", err) + } + if m.session != 0 { + t.Fatalf("session should remain zero, got %d", m.session) + } +} + +func TestWindowsManagerEnableNoOpWhenDisabledByEnv(t *testing.T) { + t.Setenv(EnvDisable, "true") + + m := &windowsManager{} + if err := m.Enable("00000000-0000-0000-0000-000000000000", netip.Addr{}); err != nil { + t.Fatalf("Enable should be a no-op when firewall disabled by env: %v", err) + } + if m.session != 0 { + t.Fatalf("session must remain zero when env disables firewall, got %d", m.session) + } +} + +func TestWindowsManagerEnableNoOpWhenPortsEmpty(t *testing.T) { + t.Setenv(EnvPorts, "") + + m := &windowsManager{} + if err := m.Enable("00000000-0000-0000-0000-000000000000", netip.Addr{}); err != nil { + t.Fatalf("Enable should be a no-op when ports list is empty: %v", err) + } + if m.session != 0 { + t.Fatalf("session must remain zero when ports list is empty, got %d", m.session) + } +}