Merge remote-tracking branch 'origin/main' into proto-ipv6-overlay

This commit is contained in:
Viktor Liu
2026-03-25 09:54:47 +01:00
18 changed files with 659 additions and 163 deletions

View File

@@ -170,6 +170,7 @@ jobs:
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
- name: Decode GPG signing key
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
env:
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
run: |
@@ -309,6 +310,7 @@ jobs:
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
- name: Decode GPG signing key
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
env:
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
run: |

View File

@@ -171,6 +171,7 @@ nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client.
homepage: https://netbird.io/
license: BSD-3-Clause
id: netbird_deb
bindir: /usr/bin
builds:
@@ -184,6 +185,7 @@ nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client.
homepage: https://netbird.io/
license: BSD-3-Clause
id: netbird_rpm
bindir: /usr/bin
builds:

View File

@@ -17,8 +17,7 @@ ENV \
NETBIRD_BIN="/usr/local/bin/netbird" \
NB_LOG_FILE="console,/var/log/netbird/client.log" \
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="30" \
NB_ENTRYPOINT_LOGIN_TIMEOUT="30"
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -23,8 +23,7 @@ ENV \
NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
NB_LOG_FILE="console,/var/lib/netbird/client.log" \
NB_DISABLE_DNS="true" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="30" \
NB_ENTRYPOINT_LOGIN_TIMEOUT="30"
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -181,10 +181,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if stateWasDown {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird up")
time.Sleep(time.Second * 10)
}
cmd.Println("netbird up")
time.Sleep(time.Second * 10)
}
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
@@ -199,9 +200,10 @@ func runForDuration(cmd *cobra.Command, args []string) error {
}
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird down")
}
cmd.Println("netbird down")
time.Sleep(1 * time.Second)
@@ -209,13 +211,14 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if _, err := client.SetSyncResponsePersistence(cmd.Context(), &proto.SetSyncResponsePersistenceRequest{
Enabled: true,
}); err != nil {
return fmt.Errorf("failed to enable sync response persistence: %v", status.Convert(err).Message())
cmd.PrintErrf("Failed to enable sync response persistence: %v\n", status.Convert(err).Message())
}
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird up")
}
cmd.Println("netbird up")
time.Sleep(3 * time.Second)
@@ -263,16 +266,18 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird down")
}
cmd.Println("netbird down")
}
if !initialLevelTrace {
if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil {
return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message())
cmd.PrintErrf("Failed to restore log level: %v\n", status.Convert(err).Message())
} else {
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
}
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
}
cmd.Printf("Local file:\n%s\n", resp.GetPath())

View File

@@ -218,7 +218,7 @@ func runHealthCheck(cmd *cobra.Command) error {
ctx := internal.CtxInitState(cmd.Context())
isStartup := check == "startup"
resp, err := getStatus(ctx, isStartup, isStartup)
resp, err := getStatus(ctx, isStartup, false)
if err != nil {
return err
}

View File

@@ -2,7 +2,6 @@
set -eEuo pipefail
: ${NB_ENTRYPOINT_SERVICE_TIMEOUT:="30"}
: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="30"}
NETBIRD_BIN="${NETBIRD_BIN:-"netbird"}"
export NB_LOG_FILE="${NB_LOG_FILE:-"console,/var/log/netbird/client.log"}"
service_pids=()
@@ -51,31 +50,10 @@ wait_for_daemon_startup() {
exit 1
}
login_if_needed() {
local timeout="${1}"
if "${NETBIRD_BIN}" status --check ready 2>/dev/null; then
info "already logged in, skipping 'netbird up'..."
return
fi
if [[ "${timeout}" -eq 0 ]]; then
info "logging in..."
"${NETBIRD_BIN}" up
return
fi
local deadline=$((SECONDS + timeout))
while [[ "${SECONDS}" -lt "${deadline}" ]]; do
if "${NETBIRD_BIN}" status --check ready 2>/dev/null; then
info "already logged in, skipping 'netbird up'..."
return
fi
sleep 1
done
info "logging in..."
connect() {
info "running 'netbird up'..."
"${NETBIRD_BIN}" up
return $?
}
main() {
@@ -85,7 +63,7 @@ main() {
info "registered new service process 'netbird service run', currently running: ${service_pids[@]@Q}"
wait_for_daemon_startup "${NB_ENTRYPOINT_SERVICE_TIMEOUT}"
login_if_needed "${NB_ENTRYPOINT_LOGIN_TIMEOUT}"
connect
wait "${service_pids[@]}"
}

3
go.mod
View File

@@ -30,10 +30,10 @@ require (
require (
fyne.io/fyne/v2 v2.7.0
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
github.com/awnumar/memguard v0.23.0
github.com/aws/aws-sdk-go-v2 v1.36.3
github.com/aws/aws-sdk-go-v2/config v1.29.14
github.com/aws/aws-sdk-go-v2/credentials v1.17.67
github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2
github.com/c-robinson/iplib v1.0.3
github.com/caddyserver/certmagic v0.21.3
@@ -144,7 +144,6 @@ require (
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
github.com/awnumar/memcall v0.4.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect

2
go.sum
View File

@@ -34,8 +34,6 @@ github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSC
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJTa/axESA0889D3UlZbLo=
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I=
github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI=
github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=

View File

@@ -262,7 +262,9 @@ func (s *Service) ToAPIResponse() *api.Service {
if opts == nil {
opts = &api.ServiceTargetOptions{}
}
opts.ProxyProtocol = &target.ProxyProtocol
if target.ProxyProtocol {
opts.ProxyProtocol = &target.ProxyProtocol
}
st.Options = opts
apiTargets = append(apiTargets, st)
}
@@ -848,7 +850,7 @@ func IsPortBasedProtocol(mode string) bool {
}
const (
maxCustomHeaders = 16
maxCustomHeaders = 16
maxHeaderKeyLen = 128
maxHeaderValueLen = 4096
)
@@ -945,7 +947,6 @@ func containsCRLF(s string) bool {
}
func validateHeaderAuths(headers []*HeaderAuthConfig) error {
seen := make(map[string]struct{})
for i, h := range headers {
if h == nil || !h.Enabled {
continue
@@ -966,10 +967,6 @@ func validateHeaderAuths(headers []*HeaderAuthConfig) error {
if canonical == "Host" {
return fmt.Errorf("header_auths[%d]: Host header cannot be used for auth", i)
}
if _, dup := seen[canonical]; dup {
return fmt.Errorf("header_auths[%d]: duplicate header %q (same canonical form already configured)", i, h.Header)
}
seen[canonical] = struct{}{}
if len(h.Value) > maxHeaderValueLen {
return fmt.Errorf("header_auths[%d]: value exceeds maximum length of %d", i, maxHeaderValueLen)
}

View File

@@ -935,3 +935,107 @@ func TestExposeServiceRequest_Validate_HTTPAllowsAuth(t *testing.T) {
req := ExposeServiceRequest{Port: 8080, Mode: "http", Pin: "123456"}
require.NoError(t, req.Validate())
}
func TestValidate_HeaderAuths(t *testing.T) {
t.Run("single valid header", func(t *testing.T) {
rp := validProxy()
rp.Auth = AuthConfig{
HeaderAuths: []*HeaderAuthConfig{
{Enabled: true, Header: "X-API-Key", Value: "secret"},
},
}
require.NoError(t, rp.Validate())
})
t.Run("multiple headers same canonical name allowed", func(t *testing.T) {
rp := validProxy()
rp.Auth = AuthConfig{
HeaderAuths: []*HeaderAuthConfig{
{Enabled: true, Header: "Authorization", Value: "Bearer token-1"},
{Enabled: true, Header: "Authorization", Value: "Bearer token-2"},
},
}
require.NoError(t, rp.Validate())
})
t.Run("multiple headers different case same canonical allowed", func(t *testing.T) {
rp := validProxy()
rp.Auth = AuthConfig{
HeaderAuths: []*HeaderAuthConfig{
{Enabled: true, Header: "x-api-key", Value: "key-1"},
{Enabled: true, Header: "X-Api-Key", Value: "key-2"},
},
}
require.NoError(t, rp.Validate())
})
t.Run("multiple different headers allowed", func(t *testing.T) {
rp := validProxy()
rp.Auth = AuthConfig{
HeaderAuths: []*HeaderAuthConfig{
{Enabled: true, Header: "Authorization", Value: "Bearer tok"},
{Enabled: true, Header: "X-API-Key", Value: "key"},
},
}
require.NoError(t, rp.Validate())
})
t.Run("empty header name rejected", func(t *testing.T) {
rp := validProxy()
rp.Auth = AuthConfig{
HeaderAuths: []*HeaderAuthConfig{
{Enabled: true, Header: "", Value: "val"},
},
}
err := rp.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "header name is required")
})
t.Run("hop-by-hop header rejected", func(t *testing.T) {
rp := validProxy()
rp.Auth = AuthConfig{
HeaderAuths: []*HeaderAuthConfig{
{Enabled: true, Header: "Connection", Value: "val"},
},
}
err := rp.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "hop-by-hop")
})
t.Run("host header rejected", func(t *testing.T) {
rp := validProxy()
rp.Auth = AuthConfig{
HeaderAuths: []*HeaderAuthConfig{
{Enabled: true, Header: "Host", Value: "val"},
},
}
err := rp.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "Host header cannot be used")
})
t.Run("disabled entries skipped", func(t *testing.T) {
rp := validProxy()
rp.Auth = AuthConfig{
HeaderAuths: []*HeaderAuthConfig{
{Enabled: false, Header: "", Value: ""},
{Enabled: true, Header: "X-Key", Value: "val"},
},
}
require.NoError(t, rp.Validate())
})
t.Run("value too long rejected", func(t *testing.T) {
rp := validProxy()
rp.Auth = AuthConfig{
HeaderAuths: []*HeaderAuthConfig{
{Enabled: true, Header: "X-Key", Value: strings.Repeat("a", maxHeaderValueLen+1)},
},
}
err := rp.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "exceeds maximum length")
})
}

View File

@@ -197,6 +197,7 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
case "jumpcloud":
return NewJumpCloudManager(JumpCloudClientConfig{
APIToken: config.ExtraConfig["ApiToken"],
ApiUrl: config.ExtraConfig["ApiUrl"],
}, appMetrics)
case "pocketid":
return NewPocketIdManager(PocketIdClientConfig{

View File

@@ -1,24 +1,40 @@
package idp
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
v1 "github.com/TheJumpCloud/jcapi-go/v1"
"github.com/netbirdio/netbird/management/server/telemetry"
)
const (
contentType = "application/json"
accept = "application/json"
jumpCloudDefaultApiUrl = "https://console.jumpcloud.com"
jumpCloudSearchPageSize = 100
)
// jumpCloudUser represents a JumpCloud V1 API system user.
type jumpCloudUser struct {
ID string `json:"_id"`
Email string `json:"email"`
Firstname string `json:"firstname"`
Middlename string `json:"middlename"`
Lastname string `json:"lastname"`
}
// jumpCloudUserList represents the response from the JumpCloud search endpoint.
type jumpCloudUserList struct {
Results []jumpCloudUser `json:"results"`
TotalCount int `json:"totalCount"`
}
// JumpCloudManager JumpCloud manager client instance.
type JumpCloudManager struct {
client *v1.APIClient
apiBase string
apiToken string
httpClient ManagerHTTPClient
credentials ManagerCredentials
@@ -29,6 +45,7 @@ type JumpCloudManager struct {
// JumpCloudClientConfig JumpCloud manager client configurations.
type JumpCloudClientConfig struct {
APIToken string
ApiUrl string
}
// JumpCloudCredentials JumpCloud authentication information.
@@ -55,7 +72,15 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM
return nil, fmt.Errorf("jumpCloud IdP configuration is incomplete, ApiToken is missing")
}
client := v1.NewAPIClient(v1.NewConfiguration())
apiBase := config.ApiUrl
if apiBase == "" {
apiBase = jumpCloudDefaultApiUrl
}
apiBase = strings.TrimSuffix(apiBase, "/")
if !strings.HasSuffix(apiBase, "/api") {
apiBase += "/api"
}
credentials := &JumpCloudCredentials{
clientConfig: config,
httpClient: httpClient,
@@ -64,7 +89,7 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM
}
return &JumpCloudManager{
client: client,
apiBase: apiBase,
apiToken: config.APIToken,
httpClient: httpClient,
credentials: credentials,
@@ -78,37 +103,58 @@ func (jc *JumpCloudCredentials) Authenticate(_ context.Context) (JWTToken, error
return JWTToken{}, nil
}
func (jm *JumpCloudManager) authenticationContext() context.Context {
return context.WithValue(context.Background(), v1.ContextAPIKey, v1.APIKey{
Key: jm.apiToken,
})
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
}
// GetUserDataByID requests user data from JumpCloud via ID.
func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
authCtx := jm.authenticationContext()
user, resp, err := jm.client.SystemusersApi.SystemusersGet(authCtx, userID, contentType, accept, nil)
// doRequest executes an HTTP request against the JumpCloud V1 API.
func (jm *JumpCloudManager) doRequest(ctx context.Context, method, path string, body io.Reader) ([]byte, error) {
reqURL := jm.apiBase + path
req, err := http.NewRequestWithContext(ctx, method, reqURL, body)
if err != nil {
return nil, err
}
req.Header.Set("x-api-key", jm.apiToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
resp, err := jm.httpClient.Do(req)
if err != nil {
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
return nil, fmt.Errorf("JumpCloud API request %s %s failed with status %d", method, path, resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
}
// GetUserDataByID requests user data from JumpCloud via ID.
func (jm *JumpCloudManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
body, err := jm.doRequest(ctx, http.MethodGet, "/systemusers/"+userID, nil)
if err != nil {
return nil, err
}
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountGetUserDataByID()
}
var user jumpCloudUser
if err = jm.helper.Unmarshal(body, &user); err != nil {
return nil, err
}
userData := parseJumpCloudUser(user)
userData.AppMetadata = appMetadata
@@ -116,30 +162,20 @@ func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, ap
}
// GetAccount returns all the users for a given profile.
func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
authCtx := jm.authenticationContext()
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
func (jm *JumpCloudManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
allUsers, err := jm.searchAllUsers(ctx)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get account %s users, statusCode %d", accountID, resp.StatusCode)
}
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountGetAccount()
}
users := make([]*UserData, 0)
for _, user := range userList.Results {
users := make([]*UserData, 0, len(allUsers))
for _, user := range allUsers {
userData := parseJumpCloudUser(user)
userData.AppMetadata.WTAccountID = accountID
users = append(users, userData)
}
@@ -148,27 +184,18 @@ func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
authCtx := jm.authenticationContext()
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
func (jm *JumpCloudManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
allUsers, err := jm.searchAllUsers(ctx)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
}
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountGetAllAccounts()
}
indexedUsers := make(map[string][]*UserData)
for _, user := range userList.Results {
for _, user := range allUsers {
userData := parseJumpCloudUser(user)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
}
@@ -176,6 +203,41 @@ func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*Use
return indexedUsers, nil
}
// searchAllUsers paginates through all system users using limit/skip.
func (jm *JumpCloudManager) searchAllUsers(ctx context.Context) ([]jumpCloudUser, error) {
var allUsers []jumpCloudUser
for skip := 0; ; skip += jumpCloudSearchPageSize {
searchReq := map[string]int{
"limit": jumpCloudSearchPageSize,
"skip": skip,
}
payload, err := json.Marshal(searchReq)
if err != nil {
return nil, err
}
body, err := jm.doRequest(ctx, http.MethodPost, "/search/systemusers", bytes.NewReader(payload))
if err != nil {
return nil, err
}
var userList jumpCloudUserList
if err = jm.helper.Unmarshal(body, &userList); err != nil {
return nil, err
}
allUsers = append(allUsers, userList.Results...)
if skip+len(userList.Results) >= userList.TotalCount {
break
}
}
return allUsers, nil
}
// CreateUser creates a new user in JumpCloud Idp and sends an invitation.
func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
@@ -183,7 +245,7 @@ func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*U
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
func (jm *JumpCloudManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
searchFilter := map[string]interface{}{
"searchFilter": map[string]interface{}{
"filter": []string{email},
@@ -191,25 +253,26 @@ func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*
},
}
authCtx := jm.authenticationContext()
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, searchFilter)
payload, err := json.Marshal(searchFilter)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode)
body, err := jm.doRequest(ctx, http.MethodPost, "/search/systemusers", bytes.NewReader(payload))
if err != nil {
return nil, err
}
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountGetUserByEmail()
}
usersData := make([]*UserData, 0)
var userList jumpCloudUserList
if err = jm.helper.Unmarshal(body, &userList); err != nil {
return nil, err
}
usersData := make([]*UserData, 0, len(userList.Results))
for _, user := range userList.Results {
usersData = append(usersData, parseJumpCloudUser(user))
}
@@ -224,20 +287,11 @@ func (jm *JumpCloudManager) InviteUserByID(_ context.Context, _ string) error {
}
// DeleteUser from jumpCloud directory
func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error {
authCtx := jm.authenticationContext()
_, resp, err := jm.client.SystemusersApi.SystemusersDelete(authCtx, userID, contentType, accept, nil)
func (jm *JumpCloudManager) DeleteUser(ctx context.Context, userID string) error {
_, err := jm.doRequest(ctx, http.MethodDelete, "/systemusers/"+userID, nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
}
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountDeleteUser()
@@ -247,11 +301,11 @@ func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error {
}
// parseJumpCloudUser parse JumpCloud system user returned from API V1 to UserData.
func parseJumpCloudUser(user v1.Systemuserreturn) *UserData {
func parseJumpCloudUser(user jumpCloudUser) *UserData {
names := []string{user.Firstname, user.Middlename, user.Lastname}
return &UserData{
Email: user.Email,
Name: strings.Join(names, " "),
ID: user.Id,
ID: user.ID,
}
}

View File

@@ -1,8 +1,15 @@
package idp
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -44,3 +51,212 @@ func TestNewJumpCloudManager(t *testing.T) {
})
}
}
func TestJumpCloudGetUserDataByID(t *testing.T) {
userResponse := jumpCloudUser{
ID: "user123",
Email: "test@example.com",
Firstname: "John",
Middlename: "",
Lastname: "Doe",
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/systemusers/user123", r.URL.Path)
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, "test-api-key", r.Header.Get("x-api-key"))
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(userResponse)
}))
defer server.Close()
manager := newTestJumpCloudManager(t, server.URL)
userData, err := manager.GetUserDataByID(context.Background(), "user123", AppMetadata{WTAccountID: "acc1"})
require.NoError(t, err)
assert.Equal(t, "user123", userData.ID)
assert.Equal(t, "test@example.com", userData.Email)
assert.Equal(t, "John Doe", userData.Name)
assert.Equal(t, "acc1", userData.AppMetadata.WTAccountID)
}
func TestJumpCloudGetAccount(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/search/systemusers", r.URL.Path)
assert.Equal(t, http.MethodPost, r.Method)
var reqBody map[string]any
assert.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody))
assert.Contains(t, reqBody, "limit")
assert.Contains(t, reqBody, "skip")
resp := jumpCloudUserList{
Results: []jumpCloudUser{
{ID: "u1", Email: "a@test.com", Firstname: "Alice", Lastname: "Smith"},
{ID: "u2", Email: "b@test.com", Firstname: "Bob", Lastname: "Jones"},
},
TotalCount: 2,
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
manager := newTestJumpCloudManager(t, server.URL)
users, err := manager.GetAccount(context.Background(), "testAccount")
require.NoError(t, err)
assert.Len(t, users, 2)
assert.Equal(t, "testAccount", users[0].AppMetadata.WTAccountID)
assert.Equal(t, "testAccount", users[1].AppMetadata.WTAccountID)
}
func TestJumpCloudGetAllAccounts(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := jumpCloudUserList{
Results: []jumpCloudUser{
{ID: "u1", Email: "a@test.com", Firstname: "Alice"},
{ID: "u2", Email: "b@test.com", Firstname: "Bob"},
},
TotalCount: 2,
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
manager := newTestJumpCloudManager(t, server.URL)
indexedUsers, err := manager.GetAllAccounts(context.Background())
require.NoError(t, err)
assert.Len(t, indexedUsers[UnsetAccountID], 2)
}
func TestJumpCloudGetAllAccountsPagination(t *testing.T) {
totalUsers := 250
allUsers := make([]jumpCloudUser, totalUsers)
for i := range allUsers {
allUsers[i] = jumpCloudUser{
ID: fmt.Sprintf("u%d", i),
Email: fmt.Sprintf("user%d@test.com", i),
Firstname: fmt.Sprintf("User%d", i),
}
}
requestCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var reqBody map[string]int
assert.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody))
limit := reqBody["limit"]
skip := reqBody["skip"]
requestCount++
end := skip + limit
if end > totalUsers {
end = totalUsers
}
resp := jumpCloudUserList{
Results: allUsers[skip:end],
TotalCount: totalUsers,
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
manager := newTestJumpCloudManager(t, server.URL)
indexedUsers, err := manager.GetAllAccounts(context.Background())
require.NoError(t, err)
assert.Len(t, indexedUsers[UnsetAccountID], totalUsers)
assert.Equal(t, 3, requestCount, "should require 3 pages for 250 users at page size 100")
}
func TestJumpCloudGetUserByEmail(t *testing.T) {
searchResponse := jumpCloudUserList{
Results: []jumpCloudUser{
{ID: "u1", Email: "alice@test.com", Firstname: "Alice", Lastname: "Smith"},
},
TotalCount: 1,
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/search/systemusers", r.URL.Path)
assert.Equal(t, http.MethodPost, r.Method)
body, err := io.ReadAll(r.Body)
assert.NoError(t, err)
assert.Contains(t, string(body), "alice@test.com")
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(searchResponse)
}))
defer server.Close()
manager := newTestJumpCloudManager(t, server.URL)
users, err := manager.GetUserByEmail(context.Background(), "alice@test.com")
require.NoError(t, err)
assert.Len(t, users, 1)
assert.Equal(t, "alice@test.com", users[0].Email)
}
func TestJumpCloudDeleteUser(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/systemusers/user123", r.URL.Path)
assert.Equal(t, http.MethodDelete, r.Method)
assert.Equal(t, "test-api-key", r.Header.Get("x-api-key"))
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{"_id": "user123"})
}))
defer server.Close()
manager := newTestJumpCloudManager(t, server.URL)
err := manager.DeleteUser(context.Background(), "user123")
require.NoError(t, err)
}
func TestJumpCloudAPIError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}))
defer server.Close()
manager := newTestJumpCloudManager(t, server.URL)
_, err := manager.GetUserDataByID(context.Background(), "user123", AppMetadata{})
require.Error(t, err)
assert.Contains(t, err.Error(), "401")
}
func TestParseJumpCloudUser(t *testing.T) {
user := jumpCloudUser{
ID: "abc123",
Email: "test@example.com",
Firstname: "John",
Middlename: "M",
Lastname: "Doe",
}
userData := parseJumpCloudUser(user)
assert.Equal(t, "abc123", userData.ID)
assert.Equal(t, "test@example.com", userData.Email)
assert.Equal(t, "John M Doe", userData.Name)
}
func newTestJumpCloudManager(t *testing.T, apiBase string) *JumpCloudManager {
t.Helper()
return &JumpCloudManager{
apiBase: apiBase,
apiToken: "test-api-key",
httpClient: http.DefaultClient,
helper: JsonParser{},
appMetrics: nil,
}
}

View File

@@ -249,7 +249,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
if err != nil {
newLabel = ""
} else {
_, err := transaction.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, update.Name)
_, err := transaction.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, newLabel)
if err == nil {
newLabel = ""
}

View File

@@ -37,6 +37,7 @@ import (
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/util"
@@ -2739,3 +2740,70 @@ func TestProcessPeerAddAuth(t *testing.T) {
assert.Empty(t, config.GroupsToAdd)
})
}
func TestUpdatePeer_DnsLabelCollisionWithFQDN(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
// Add first peer with hostname that produces DNS label "netbird1"
key1, err := wgtypes.GenerateKey()
require.NoError(t, err)
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "netbird1.netbird.cloud"},
}, false)
require.NoError(t, err, "unable to add first peer")
assert.Equal(t, "netbird1", peer1.DNSLabel)
// Add second peer with a different hostname
key2, err := wgtypes.GenerateKey()
require.NoError(t, err)
peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "ip-10-29-5-130"},
}, false)
require.NoError(t, err)
update := peer2.Copy()
update.Name = "netbird1.demo.netbird.cloud"
updated, err := manager.UpdatePeer(context.Background(), accountID, userID, update)
require.NoError(t, err, "renaming peer should not fail with duplicate DNS label error")
assert.Equal(t, "netbird1.demo.netbird.cloud", updated.Name)
assert.NotEqual(t, "netbird1", updated.DNSLabel, "DNS label should not collide with existing peer")
assert.Contains(t, updated.DNSLabel, "netbird1-", "DNS label should be IP-based fallback")
}
func TestUpdatePeer_DnsLabelUniqueName(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
key1, err := wgtypes.GenerateKey()
require.NoError(t, err)
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "web-server"},
}, false)
require.NoError(t, err)
assert.Equal(t, "web-server", peer1.DNSLabel)
// Add second peer and rename it to a unique FQDN whose first label doesn't collide
key2, err := wgtypes.GenerateKey()
require.NoError(t, err)
peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "old-name"},
}, false)
require.NoError(t, err)
update := peer2.Copy()
update.Name = "api-server.example.com"
updated, err := manager.UpdatePeer(context.Background(), accountID, userID, update)
require.NoError(t, err, "renaming to unique FQDN should succeed")
assert.Equal(t, "api-server", updated.DNSLabel, "DNS label should be first label of FQDN")
}

View File

@@ -932,3 +932,71 @@ func TestProtect_HeaderAuth_SubsequentRequestUsesSessionCookie(t *testing.T) {
assert.Equal(t, "header-user", capturedData2.GetUserID())
assert.Equal(t, "header", capturedData2.GetAuthMethod())
}
// TestProtect_HeaderAuth_MultipleValuesSameHeader verifies that the proxy
// correctly handles multiple valid credentials for the same header name.
// In production, the mgmt gRPC authenticateHeader iterates all configured
// header auths and accepts if any hash matches (OR semantics). The proxy
// creates one Header scheme per entry, but a single gRPC call checks all.
func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
// Mock simulates mgmt behavior: accepts either token-a or token-b.
accepted := map[string]bool{"Bearer token-a": true, "Bearer token-b": true}
mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
ha := req.GetHeaderAuth()
if ha != nil && accepted[ha.GetHeaderValue()] {
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour)
require.NoError(t, err)
return &proto.AuthenticateResponse{Success: true, SessionToken: token}, nil
}
return &proto.AuthenticateResponse{Success: false}, nil
}}
// Single Header scheme (as if one entry existed), but the mock checks both values.
hdr := NewHeader(mock, "svc1", "acc1", "Authorization")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
var backendCalled bool
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
backendCalled = true
w.WriteHeader(http.StatusOK)
}))
t.Run("first value accepted", func(t *testing.T) {
backendCalled = false
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.Header.Set("Authorization", "Bearer token-a")
req = req.WithContext(proxy.WithCapturedData(req.Context(), proxy.NewCapturedData("")))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.True(t, backendCalled, "first token should be accepted")
})
t.Run("second value accepted", func(t *testing.T) {
backendCalled = false
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.Header.Set("Authorization", "Bearer token-b")
req = req.WithContext(proxy.WithCapturedData(req.Context(), proxy.NewCapturedData("")))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.True(t, backendCalled, "second token should be accepted")
})
t.Run("unknown value rejected", func(t *testing.T) {
backendCalled = false
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.Header.Set("Authorization", "Bearer token-c")
req = req.WithContext(proxy.WithCapturedData(req.Context(), proxy.NewCapturedData("")))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnauthorized, rec.Code)
assert.False(t, backendCalled, "unknown token should be rejected")
})
}

View File

@@ -5,13 +5,12 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"runtime"
"testing"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
@@ -20,45 +19,55 @@ import (
)
func Test_S3HandlerGetUploadURL(t *testing.T) {
if runtime.GOOS != "linux" && os.Getenv("CI") == "true" {
t.Skip("Skipping test on non-Linux and CI environment due to docker dependency")
}
if runtime.GOOS == "windows" {
t.Skip("Skipping test on Windows due to potential docker dependency")
if runtime.GOOS != "linux" {
t.Skip("Skipping test on non-Linux due to docker dependency")
}
awsEndpoint := "http://127.0.0.1:4566"
awsRegion := "us-east-1"
ctx := context.Background()
containerRequest := testcontainers.ContainerRequest{
Image: "localstack/localstack:s3-latest",
ExposedPorts: []string{"4566:4566/tcp"},
WaitingFor: wait.ForLog("Ready"),
}
c, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
ContainerRequest: containerRequest,
Started: true,
ContainerRequest: testcontainers.ContainerRequest{
Image: "minio/minio:RELEASE.2025-04-22T22-12-26Z",
ExposedPorts: []string{"9000/tcp"},
Env: map[string]string{
"MINIO_ROOT_USER": "minioadmin",
"MINIO_ROOT_PASSWORD": "minioadmin",
},
Cmd: []string{"server", "/data"},
WaitingFor: wait.ForHTTP("/minio/health/ready").WithPort("9000"),
},
Started: true,
})
if err != nil {
t.Error(err)
}
defer func(c testcontainers.Container, ctx context.Context) {
require.NoError(t, err)
t.Cleanup(func() {
if err := c.Terminate(ctx); err != nil {
t.Log(err)
}
}(c, ctx)
})
mappedPort, err := c.MappedPort(ctx, "9000")
require.NoError(t, err)
hostIP, err := c.Host(ctx)
require.NoError(t, err)
awsEndpoint := "http://" + hostIP + ":" + mappedPort.Port()
t.Setenv("AWS_REGION", awsRegion)
t.Setenv("AWS_ENDPOINT_URL", awsEndpoint)
t.Setenv("AWS_ACCESS_KEY_ID", "test")
t.Setenv("AWS_SECRET_ACCESS_KEY", "test")
t.Setenv("AWS_ACCESS_KEY_ID", "minioadmin")
t.Setenv("AWS_SECRET_ACCESS_KEY", "minioadmin")
t.Setenv("AWS_CONFIG_FILE", "")
t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "")
t.Setenv("AWS_PROFILE", "")
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(awsRegion), config.WithBaseEndpoint(awsEndpoint))
if err != nil {
t.Error(err)
}
cfg, err := config.LoadDefaultConfig(ctx,
config.WithRegion(awsRegion),
config.WithBaseEndpoint(awsEndpoint),
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider("minioadmin", "minioadmin", "")),
)
require.NoError(t, err)
client := s3.NewFromConfig(cfg, func(o *s3.Options) {
o.UsePathStyle = true
@@ -66,19 +75,16 @@ func Test_S3HandlerGetUploadURL(t *testing.T) {
})
bucketName := "test"
if _, err := client.CreateBucket(ctx, &s3.CreateBucketInput{
_, err = client.CreateBucket(ctx, &s3.CreateBucketInput{
Bucket: &bucketName,
}); err != nil {
t.Error(err)
}
})
require.NoError(t, err)
list, err := client.ListBuckets(ctx, &s3.ListBucketsInput{})
if err != nil {
t.Error(err)
}
require.NoError(t, err)
assert.Equal(t, len(list.Buckets), 1)
assert.Equal(t, *list.Buckets[0].Name, bucketName)
require.Len(t, list.Buckets, 1)
require.Equal(t, bucketName, *list.Buckets[0].Name)
t.Setenv(bucketVar, bucketName)