diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index 48873f640..39da2a79f 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -13,6 +13,7 @@ import ( "net" "net/http" "net/url" + "strconv" "strings" "time" @@ -46,9 +47,10 @@ type PKCEAuthorizationFlow struct { func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) { var availableRedirectURL string - // find the first available redirect URL + excludedRanges := getSystemExcludedPortRanges() + for _, redirectURL := range config.RedirectURLs { - if !isRedirectURLPortUsed(redirectURL) { + if !isRedirectURLPortUsed(redirectURL, excludedRanges) { availableRedirectURL = redirectURL break } @@ -282,15 +284,22 @@ func createCodeChallenge(codeVerifier string) string { return base64.RawURLEncoding.EncodeToString(sha2[:]) } -// isRedirectURLPortUsed checks if the port used in the redirect URL is in use. -func isRedirectURLPortUsed(redirectURL string) bool { +// isRedirectURLPortUsed checks if the port used in the redirect URL is in use or excluded on Windows. +func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRange) bool { parsedURL, err := url.Parse(redirectURL) if err != nil { log.Errorf("failed to parse redirect URL: %v", err) return true } - addr := fmt.Sprintf(":%s", parsedURL.Port()) + port := parsedURL.Port() + + if isPortInExcludedRange(port, excludedRanges) { + log.Warnf("port %s is in Windows excluded port range, skipping", port) + return true + } + + addr := fmt.Sprintf(":%s", port) conn, err := net.DialTimeout("tcp", addr, 3*time.Second) if err != nil { return false @@ -304,6 +313,33 @@ func isRedirectURLPortUsed(redirectURL string) bool { return true } +// excludedPortRange represents a range of excluded ports. +type excludedPortRange struct { + start int + end int +} + +// isPortInExcludedRange checks if the given port is in any of the excluded ranges. +func isPortInExcludedRange(port string, excludedRanges []excludedPortRange) bool { + if len(excludedRanges) == 0 { + return false + } + + portNum, err := strconv.Atoi(port) + if err != nil { + log.Debugf("invalid port number %s: %v", port, err) + return false + } + + for _, r := range excludedRanges { + if portNum >= r.start && portNum <= r.end { + return true + } + } + + return false +} + func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) { tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl) if err != nil { diff --git a/client/internal/auth/pkce_flow_other.go b/client/internal/auth/pkce_flow_other.go new file mode 100644 index 000000000..96df41539 --- /dev/null +++ b/client/internal/auth/pkce_flow_other.go @@ -0,0 +1,8 @@ +//go:build !windows + +package auth + +// getSystemExcludedPortRanges returns nil on non-Windows platforms. +func getSystemExcludedPortRanges() []excludedPortRange { + return nil +} diff --git a/client/internal/auth/pkce_flow_test.go b/client/internal/auth/pkce_flow_test.go index b2347d12d..380a360e5 100644 --- a/client/internal/auth/pkce_flow_test.go +++ b/client/internal/auth/pkce_flow_test.go @@ -2,8 +2,11 @@ package auth import ( "context" + "fmt" + "net" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/client/internal" @@ -69,3 +72,147 @@ func TestPromptLogin(t *testing.T) { }) } } + +func TestIsPortInExcludedRange(t *testing.T) { + tests := []struct { + name string + port string + excludedRanges []excludedPortRange + expectedBlocked bool + }{ + { + name: "Port in excluded range", + port: "8080", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: true, + }, + { + name: "Port at start of range", + port: "8000", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: true, + }, + { + name: "Port at end of range", + port: "8100", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: true, + }, + { + name: "Port before range", + port: "7999", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: false, + }, + { + name: "Port after range", + port: "8101", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: false, + }, + { + name: "Empty excluded ranges", + port: "8080", + excludedRanges: []excludedPortRange{}, + expectedBlocked: false, + }, + { + name: "Nil excluded ranges", + port: "8080", + excludedRanges: nil, + expectedBlocked: false, + }, + { + name: "Multiple ranges - port in second range", + port: "9050", + excludedRanges: []excludedPortRange{ + {start: 8000, end: 8100}, + {start: 9000, end: 9100}, + }, + expectedBlocked: true, + }, + { + name: "Multiple ranges - port not in any range", + port: "8500", + excludedRanges: []excludedPortRange{ + {start: 8000, end: 8100}, + {start: 9000, end: 9100}, + }, + expectedBlocked: false, + }, + { + name: "Invalid port string", + port: "invalid", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: false, + }, + { + name: "Empty port string", + port: "", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isPortInExcludedRange(tt.port, tt.excludedRanges) + assert.Equal(t, tt.expectedBlocked, result, "Port exclusion check mismatch") + }) + } +} + +func TestIsRedirectURLPortUsed(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer func() { + _ = listener.Close() + }() + + usedPort := listener.Addr().(*net.TCPAddr).Port + + tests := []struct { + name string + redirectURL string + excludedRanges []excludedPortRange + expectedUsed bool + }{ + { + name: "Port in excluded range", + redirectURL: "http://127.0.0.1:8080/", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedUsed: true, + }, + { + name: "Port actually in use", + redirectURL: fmt.Sprintf("http://127.0.0.1:%d/", usedPort), + excludedRanges: nil, + expectedUsed: true, + }, + { + name: "Port not in use and not excluded", + redirectURL: "http://127.0.0.1:65432/", + excludedRanges: nil, + expectedUsed: false, + }, + { + name: "Invalid URL without port", + redirectURL: "not-a-valid-url", + excludedRanges: nil, + expectedUsed: false, + }, + { + name: "Port excluded even if not in use", + redirectURL: "http://127.0.0.1:8050/", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedUsed: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isRedirectURLPortUsed(tt.redirectURL, tt.excludedRanges) + assert.Equal(t, tt.expectedUsed, result, "Port usage check mismatch") + }) + } +} diff --git a/client/internal/auth/pkce_flow_windows.go b/client/internal/auth/pkce_flow_windows.go new file mode 100644 index 000000000..cf3f8718f --- /dev/null +++ b/client/internal/auth/pkce_flow_windows.go @@ -0,0 +1,86 @@ +//go:build windows + +package auth + +import ( + "bufio" + "fmt" + "os/exec" + "strconv" + "strings" + + log "github.com/sirupsen/logrus" +) + +// getSystemExcludedPortRanges retrieves the excluded port ranges from Windows using netsh. +func getSystemExcludedPortRanges() []excludedPortRange { + ranges, err := getExcludedPortRangesFromNetsh() + if err != nil { + log.Debugf("failed to get Windows excluded port ranges: %v", err) + return nil + } + + return ranges +} + +// getExcludedPortRangesFromNetsh retrieves excluded port ranges using netsh command. +func getExcludedPortRangesFromNetsh() ([]excludedPortRange, error) { + cmd := exec.Command("netsh", "interface", "ipv4", "show", "excludedportrange", "protocol=tcp") + output, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("netsh command: %w", err) + } + + return parseExcludedPortRanges(string(output)) +} + +// parseExcludedPortRanges parses the output of the netsh command to extract port ranges. +func parseExcludedPortRanges(output string) ([]excludedPortRange, error) { + var ranges []excludedPortRange + scanner := bufio.NewScanner(strings.NewReader(output)) + + foundHeader := false + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + if strings.Contains(line, "Start Port") && strings.Contains(line, "End Port") { + foundHeader = true + continue + } + + if !foundHeader { + continue + } + + if strings.Contains(line, "----------") { + continue + } + + if line == "" { + continue + } + + fields := strings.Fields(line) + if len(fields) < 2 { + continue + } + + startPort, err := strconv.Atoi(fields[0]) + if err != nil { + continue + } + + endPort, err := strconv.Atoi(fields[1]) + if err != nil { + continue + } + + ranges = append(ranges, excludedPortRange{start: startPort, end: endPort}) + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("scan output: %w", err) + } + + return ranges, nil +} diff --git a/client/internal/auth/pkce_flow_windows_test.go b/client/internal/auth/pkce_flow_windows_test.go new file mode 100644 index 000000000..dd455b2fe --- /dev/null +++ b/client/internal/auth/pkce_flow_windows_test.go @@ -0,0 +1,116 @@ +//go:build windows + +package auth + +import ( + "fmt" + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal" +) + +func TestParseExcludedPortRanges(t *testing.T) { + tests := []struct { + name string + netshOutput string + expectedRanges []excludedPortRange + expectError bool + }{ + { + name: "Valid netsh output with multiple ranges", + netshOutput: ` +Protocol tcp Dynamic Port Range +--------------------------------- +Start Port : 49152 +Number of Ports : 16384 + +Protocol tcp Excluded Port Ranges +--------------------------------- +Start Port End Port +---------- -------- + 5357 5357 * + 50000 50059 * +`, + expectedRanges: []excludedPortRange{ + {start: 5357, end: 5357}, + {start: 50000, end: 50059}, + }, + expectError: false, + }, + { + name: "Empty output", + netshOutput: ` +Protocol tcp Dynamic Port Range +--------------------------------- +Start Port : 49152 +Number of Ports : 16384 +`, + expectedRanges: nil, + expectError: false, + }, + { + name: "Single range", + netshOutput: ` +Protocol tcp Excluded Port Ranges +--------------------------------- +Start Port End Port +---------- -------- + 8080 8090 +`, + expectedRanges: []excludedPortRange{ + {start: 8080, end: 8090}, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ranges, err := parseExcludedPortRanges(tt.netshOutput) + + if tt.expectError { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedRanges, ranges) + } + }) + } +} + +func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) { + ranges := getSystemExcludedPortRanges() + t.Logf("Found %d excluded port ranges on this system", len(ranges)) + + listener1, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer func() { + _ = listener1.Close() + }() + usedPort1 := listener1.Addr().(*net.TCPAddr).Port + + availablePort := 65432 + + config := internal.PKCEAuthProviderConfig{ + ClientID: "test-client-id", + Audience: "test-audience", + TokenEndpoint: "https://test-token-endpoint.com/token", + Scope: "openid email profile", + AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize", + RedirectURLs: []string{ + fmt.Sprintf("http://127.0.0.1:%d/", usedPort1), + fmt.Sprintf("http://127.0.0.1:%d/", availablePort), + }, + UseIDToken: true, + } + + flow, err := NewPKCEAuthorizationFlow(config) + require.NoError(t, err) + require.NotNil(t, flow) + assert.Contains(t, flow.oAuthConfig.RedirectURL, fmt.Sprintf(":%d", availablePort), + "Should skip port in use and select available port") +}