mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
* Consolidate authentication logic - Moving auth functions from client/internal to client/internal/auth package - Creating unified auth.Auth client with NewAuth() constructor - Replacing direct auth function calls with auth client methods - Refactoring device flow and PKCE flow implementations - Updating iOS/Android/server code to use new auth client API * Refactor PKCE auth and login methods - Remove unnecessary internal package reference in PKCE flow test - Adjust context assignment placement in iOS and Android login methods
222 lines
5.6 KiB
Go
222 lines
5.6 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
mgm "github.com/netbirdio/netbird/shared/management/client/common"
|
|
)
|
|
|
|
func TestPromptLogin(t *testing.T) {
|
|
const (
|
|
promptLogin = "prompt=login"
|
|
maxAge0 = "max_age=0"
|
|
)
|
|
|
|
tt := []struct {
|
|
name string
|
|
loginFlag mgm.LoginFlag
|
|
disablePromptLogin bool
|
|
expectContains []string
|
|
}{
|
|
{
|
|
name: "Prompt login",
|
|
loginFlag: mgm.LoginFlagPromptLogin,
|
|
expectContains: []string{promptLogin},
|
|
},
|
|
{
|
|
name: "Max age 0",
|
|
loginFlag: mgm.LoginFlagMaxAge0,
|
|
expectContains: []string{maxAge0},
|
|
},
|
|
{
|
|
name: "Disable prompt login",
|
|
loginFlag: mgm.LoginFlagPromptLogin,
|
|
disablePromptLogin: true,
|
|
expectContains: []string{},
|
|
},
|
|
{
|
|
name: "None flag should not add parameters",
|
|
loginFlag: mgm.LoginFlagNone,
|
|
expectContains: []string{},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tt {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
config := PKCEAuthProviderConfig{
|
|
ClientID: "test-client-id",
|
|
Audience: "test-audience",
|
|
TokenEndpoint: "https://test-token-endpoint.com/token",
|
|
Scope: "openid email profile",
|
|
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
|
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
|
UseIDToken: true,
|
|
LoginFlag: tc.loginFlag,
|
|
DisablePromptLogin: tc.disablePromptLogin,
|
|
}
|
|
pkce, err := NewPKCEAuthorizationFlow(config)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create PKCEAuthorizationFlow: %v", err)
|
|
}
|
|
authInfo, err := pkce.RequestAuthInfo(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("Failed to request auth info: %v", err)
|
|
}
|
|
|
|
for _, expected := range tc.expectContains {
|
|
require.Contains(t, authInfo.VerificationURIComplete, expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIsPortInExcludedRange(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
port string
|
|
excludedRanges []excludedPortRange
|
|
expectedBlocked bool
|
|
}{
|
|
{
|
|
name: "Port in excluded range",
|
|
port: "8080",
|
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
|
expectedBlocked: true,
|
|
},
|
|
{
|
|
name: "Port at start of range",
|
|
port: "8000",
|
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
|
expectedBlocked: true,
|
|
},
|
|
{
|
|
name: "Port at end of range",
|
|
port: "8100",
|
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
|
expectedBlocked: true,
|
|
},
|
|
{
|
|
name: "Port before range",
|
|
port: "7999",
|
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
|
expectedBlocked: false,
|
|
},
|
|
{
|
|
name: "Port after range",
|
|
port: "8101",
|
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
|
expectedBlocked: false,
|
|
},
|
|
{
|
|
name: "Empty excluded ranges",
|
|
port: "8080",
|
|
excludedRanges: []excludedPortRange{},
|
|
expectedBlocked: false,
|
|
},
|
|
{
|
|
name: "Nil excluded ranges",
|
|
port: "8080",
|
|
excludedRanges: nil,
|
|
expectedBlocked: false,
|
|
},
|
|
{
|
|
name: "Multiple ranges - port in second range",
|
|
port: "9050",
|
|
excludedRanges: []excludedPortRange{
|
|
{start: 8000, end: 8100},
|
|
{start: 9000, end: 9100},
|
|
},
|
|
expectedBlocked: true,
|
|
},
|
|
{
|
|
name: "Multiple ranges - port not in any range",
|
|
port: "8500",
|
|
excludedRanges: []excludedPortRange{
|
|
{start: 8000, end: 8100},
|
|
{start: 9000, end: 9100},
|
|
},
|
|
expectedBlocked: false,
|
|
},
|
|
{
|
|
name: "Invalid port string",
|
|
port: "invalid",
|
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
|
expectedBlocked: false,
|
|
},
|
|
{
|
|
name: "Empty port string",
|
|
port: "",
|
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
|
expectedBlocked: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := isPortInExcludedRange(tt.port, tt.excludedRanges)
|
|
assert.Equal(t, tt.expectedBlocked, result, "Port exclusion check mismatch")
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIsRedirectURLPortUsed(t *testing.T) {
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
defer func() {
|
|
_ = listener.Close()
|
|
}()
|
|
|
|
usedPort := listener.Addr().(*net.TCPAddr).Port
|
|
|
|
tests := []struct {
|
|
name string
|
|
redirectURL string
|
|
excludedRanges []excludedPortRange
|
|
expectedUsed bool
|
|
}{
|
|
{
|
|
name: "Port in excluded range",
|
|
redirectURL: "http://127.0.0.1:8080/",
|
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
|
expectedUsed: true,
|
|
},
|
|
{
|
|
name: "Port actually in use",
|
|
redirectURL: fmt.Sprintf("http://127.0.0.1:%d/", usedPort),
|
|
excludedRanges: nil,
|
|
expectedUsed: true,
|
|
},
|
|
{
|
|
name: "Port not in use and not excluded",
|
|
redirectURL: "http://127.0.0.1:65432/",
|
|
excludedRanges: nil,
|
|
expectedUsed: false,
|
|
},
|
|
{
|
|
name: "Invalid URL without port",
|
|
redirectURL: "not-a-valid-url",
|
|
excludedRanges: nil,
|
|
expectedUsed: false,
|
|
},
|
|
{
|
|
name: "Port excluded even if not in use",
|
|
redirectURL: "http://127.0.0.1:8050/",
|
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
|
expectedUsed: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := isRedirectURLPortUsed(tt.redirectURL, tt.excludedRanges)
|
|
assert.Equal(t, tt.expectedUsed, result, "Port usage check mismatch")
|
|
})
|
|
}
|
|
}
|