mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
Merge remote-tracking branch 'origin/main' into proto-ipv6-overlay
This commit is contained in:
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -170,6 +170,7 @@ jobs:
|
||||
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
||||
|
||||
- name: Decode GPG signing key
|
||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
env:
|
||||
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
|
||||
run: |
|
||||
@@ -309,6 +310,7 @@ jobs:
|
||||
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
||||
|
||||
- name: Decode GPG signing key
|
||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
env:
|
||||
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
|
||||
run: |
|
||||
|
||||
@@ -171,6 +171,7 @@ nfpms:
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
description: Netbird client.
|
||||
homepage: https://netbird.io/
|
||||
license: BSD-3-Clause
|
||||
id: netbird_deb
|
||||
bindir: /usr/bin
|
||||
builds:
|
||||
@@ -184,6 +185,7 @@ nfpms:
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
description: Netbird client.
|
||||
homepage: https://netbird.io/
|
||||
license: BSD-3-Clause
|
||||
id: netbird_rpm
|
||||
bindir: /usr/bin
|
||||
builds:
|
||||
|
||||
@@ -17,8 +17,7 @@ ENV \
|
||||
NETBIRD_BIN="/usr/local/bin/netbird" \
|
||||
NB_LOG_FILE="console,/var/log/netbird/client.log" \
|
||||
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
|
||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="30" \
|
||||
NB_ENTRYPOINT_LOGIN_TIMEOUT="30"
|
||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
|
||||
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||
|
||||
|
||||
@@ -23,8 +23,7 @@ ENV \
|
||||
NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
|
||||
NB_LOG_FILE="console,/var/lib/netbird/client.log" \
|
||||
NB_DISABLE_DNS="true" \
|
||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="30" \
|
||||
NB_ENTRYPOINT_LOGIN_TIMEOUT="30"
|
||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
|
||||
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||
|
||||
|
||||
@@ -181,10 +181,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
|
||||
if stateWasDown {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
cmd.Println("netbird up")
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
cmd.Println("netbird up")
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
|
||||
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
|
||||
@@ -199,9 +200,10 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
cmd.Println("netbird down")
|
||||
}
|
||||
cmd.Println("netbird down")
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
@@ -209,13 +211,14 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
if _, err := client.SetSyncResponsePersistence(cmd.Context(), &proto.SetSyncResponsePersistenceRequest{
|
||||
Enabled: true,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to enable sync response persistence: %v", status.Convert(err).Message())
|
||||
cmd.PrintErrf("Failed to enable sync response persistence: %v\n", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
cmd.Println("netbird up")
|
||||
}
|
||||
cmd.Println("netbird up")
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
@@ -263,16 +266,18 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
|
||||
if stateWasDown {
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
cmd.Println("netbird down")
|
||||
}
|
||||
cmd.Println("netbird down")
|
||||
}
|
||||
|
||||
if !initialLevelTrace {
|
||||
if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil {
|
||||
return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message())
|
||||
cmd.PrintErrf("Failed to restore log level: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
||||
}
|
||||
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
||||
}
|
||||
|
||||
cmd.Printf("Local file:\n%s\n", resp.GetPath())
|
||||
|
||||
@@ -218,7 +218,7 @@ func runHealthCheck(cmd *cobra.Command) error {
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
isStartup := check == "startup"
|
||||
resp, err := getStatus(ctx, isStartup, isStartup)
|
||||
resp, err := getStatus(ctx, isStartup, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
set -eEuo pipefail
|
||||
|
||||
: ${NB_ENTRYPOINT_SERVICE_TIMEOUT:="30"}
|
||||
: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="30"}
|
||||
NETBIRD_BIN="${NETBIRD_BIN:-"netbird"}"
|
||||
export NB_LOG_FILE="${NB_LOG_FILE:-"console,/var/log/netbird/client.log"}"
|
||||
service_pids=()
|
||||
@@ -51,31 +50,10 @@ wait_for_daemon_startup() {
|
||||
exit 1
|
||||
}
|
||||
|
||||
login_if_needed() {
|
||||
local timeout="${1}"
|
||||
|
||||
if "${NETBIRD_BIN}" status --check ready 2>/dev/null; then
|
||||
info "already logged in, skipping 'netbird up'..."
|
||||
return
|
||||
fi
|
||||
|
||||
if [[ "${timeout}" -eq 0 ]]; then
|
||||
info "logging in..."
|
||||
"${NETBIRD_BIN}" up
|
||||
return
|
||||
fi
|
||||
|
||||
local deadline=$((SECONDS + timeout))
|
||||
while [[ "${SECONDS}" -lt "${deadline}" ]]; do
|
||||
if "${NETBIRD_BIN}" status --check ready 2>/dev/null; then
|
||||
info "already logged in, skipping 'netbird up'..."
|
||||
return
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
info "logging in..."
|
||||
connect() {
|
||||
info "running 'netbird up'..."
|
||||
"${NETBIRD_BIN}" up
|
||||
return $?
|
||||
}
|
||||
|
||||
main() {
|
||||
@@ -85,7 +63,7 @@ main() {
|
||||
info "registered new service process 'netbird service run', currently running: ${service_pids[@]@Q}"
|
||||
|
||||
wait_for_daemon_startup "${NB_ENTRYPOINT_SERVICE_TIMEOUT}"
|
||||
login_if_needed "${NB_ENTRYPOINT_LOGIN_TIMEOUT}"
|
||||
connect
|
||||
|
||||
wait "${service_pids[@]}"
|
||||
}
|
||||
|
||||
3
go.mod
3
go.mod
@@ -30,10 +30,10 @@ require (
|
||||
require (
|
||||
fyne.io/fyne/v2 v2.7.0
|
||||
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9
|
||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
|
||||
github.com/awnumar/memguard v0.23.0
|
||||
github.com/aws/aws-sdk-go-v2 v1.36.3
|
||||
github.com/aws/aws-sdk-go-v2/config v1.29.14
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.67
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2
|
||||
github.com/c-robinson/iplib v1.0.3
|
||||
github.com/caddyserver/certmagic v0.21.3
|
||||
@@ -144,7 +144,6 @@ require (
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
|
||||
github.com/awnumar/memcall v0.4.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@@ -34,8 +34,6 @@ github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSC
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
|
||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJTa/axESA0889D3UlZbLo=
|
||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I=
|
||||
github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI=
|
||||
github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
|
||||
|
||||
@@ -262,7 +262,9 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
if opts == nil {
|
||||
opts = &api.ServiceTargetOptions{}
|
||||
}
|
||||
opts.ProxyProtocol = &target.ProxyProtocol
|
||||
if target.ProxyProtocol {
|
||||
opts.ProxyProtocol = &target.ProxyProtocol
|
||||
}
|
||||
st.Options = opts
|
||||
apiTargets = append(apiTargets, st)
|
||||
}
|
||||
@@ -848,7 +850,7 @@ func IsPortBasedProtocol(mode string) bool {
|
||||
}
|
||||
|
||||
const (
|
||||
maxCustomHeaders = 16
|
||||
maxCustomHeaders = 16
|
||||
maxHeaderKeyLen = 128
|
||||
maxHeaderValueLen = 4096
|
||||
)
|
||||
@@ -945,7 +947,6 @@ func containsCRLF(s string) bool {
|
||||
}
|
||||
|
||||
func validateHeaderAuths(headers []*HeaderAuthConfig) error {
|
||||
seen := make(map[string]struct{})
|
||||
for i, h := range headers {
|
||||
if h == nil || !h.Enabled {
|
||||
continue
|
||||
@@ -966,10 +967,6 @@ func validateHeaderAuths(headers []*HeaderAuthConfig) error {
|
||||
if canonical == "Host" {
|
||||
return fmt.Errorf("header_auths[%d]: Host header cannot be used for auth", i)
|
||||
}
|
||||
if _, dup := seen[canonical]; dup {
|
||||
return fmt.Errorf("header_auths[%d]: duplicate header %q (same canonical form already configured)", i, h.Header)
|
||||
}
|
||||
seen[canonical] = struct{}{}
|
||||
if len(h.Value) > maxHeaderValueLen {
|
||||
return fmt.Errorf("header_auths[%d]: value exceeds maximum length of %d", i, maxHeaderValueLen)
|
||||
}
|
||||
|
||||
@@ -935,3 +935,107 @@ func TestExposeServiceRequest_Validate_HTTPAllowsAuth(t *testing.T) {
|
||||
req := ExposeServiceRequest{Port: 8080, Mode: "http", Pin: "123456"}
|
||||
require.NoError(t, req.Validate())
|
||||
}
|
||||
|
||||
func TestValidate_HeaderAuths(t *testing.T) {
|
||||
t.Run("single valid header", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Auth = AuthConfig{
|
||||
HeaderAuths: []*HeaderAuthConfig{
|
||||
{Enabled: true, Header: "X-API-Key", Value: "secret"},
|
||||
},
|
||||
}
|
||||
require.NoError(t, rp.Validate())
|
||||
})
|
||||
|
||||
t.Run("multiple headers same canonical name allowed", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Auth = AuthConfig{
|
||||
HeaderAuths: []*HeaderAuthConfig{
|
||||
{Enabled: true, Header: "Authorization", Value: "Bearer token-1"},
|
||||
{Enabled: true, Header: "Authorization", Value: "Bearer token-2"},
|
||||
},
|
||||
}
|
||||
require.NoError(t, rp.Validate())
|
||||
})
|
||||
|
||||
t.Run("multiple headers different case same canonical allowed", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Auth = AuthConfig{
|
||||
HeaderAuths: []*HeaderAuthConfig{
|
||||
{Enabled: true, Header: "x-api-key", Value: "key-1"},
|
||||
{Enabled: true, Header: "X-Api-Key", Value: "key-2"},
|
||||
},
|
||||
}
|
||||
require.NoError(t, rp.Validate())
|
||||
})
|
||||
|
||||
t.Run("multiple different headers allowed", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Auth = AuthConfig{
|
||||
HeaderAuths: []*HeaderAuthConfig{
|
||||
{Enabled: true, Header: "Authorization", Value: "Bearer tok"},
|
||||
{Enabled: true, Header: "X-API-Key", Value: "key"},
|
||||
},
|
||||
}
|
||||
require.NoError(t, rp.Validate())
|
||||
})
|
||||
|
||||
t.Run("empty header name rejected", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Auth = AuthConfig{
|
||||
HeaderAuths: []*HeaderAuthConfig{
|
||||
{Enabled: true, Header: "", Value: "val"},
|
||||
},
|
||||
}
|
||||
err := rp.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "header name is required")
|
||||
})
|
||||
|
||||
t.Run("hop-by-hop header rejected", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Auth = AuthConfig{
|
||||
HeaderAuths: []*HeaderAuthConfig{
|
||||
{Enabled: true, Header: "Connection", Value: "val"},
|
||||
},
|
||||
}
|
||||
err := rp.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "hop-by-hop")
|
||||
})
|
||||
|
||||
t.Run("host header rejected", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Auth = AuthConfig{
|
||||
HeaderAuths: []*HeaderAuthConfig{
|
||||
{Enabled: true, Header: "Host", Value: "val"},
|
||||
},
|
||||
}
|
||||
err := rp.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "Host header cannot be used")
|
||||
})
|
||||
|
||||
t.Run("disabled entries skipped", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Auth = AuthConfig{
|
||||
HeaderAuths: []*HeaderAuthConfig{
|
||||
{Enabled: false, Header: "", Value: ""},
|
||||
{Enabled: true, Header: "X-Key", Value: "val"},
|
||||
},
|
||||
}
|
||||
require.NoError(t, rp.Validate())
|
||||
})
|
||||
|
||||
t.Run("value too long rejected", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Auth = AuthConfig{
|
||||
HeaderAuths: []*HeaderAuthConfig{
|
||||
{Enabled: true, Header: "X-Key", Value: strings.Repeat("a", maxHeaderValueLen+1)},
|
||||
},
|
||||
}
|
||||
err := rp.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "exceeds maximum length")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -197,6 +197,7 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
|
||||
case "jumpcloud":
|
||||
return NewJumpCloudManager(JumpCloudClientConfig{
|
||||
APIToken: config.ExtraConfig["ApiToken"],
|
||||
ApiUrl: config.ExtraConfig["ApiUrl"],
|
||||
}, appMetrics)
|
||||
case "pocketid":
|
||||
return NewPocketIdManager(PocketIdClientConfig{
|
||||
|
||||
@@ -1,24 +1,40 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
v1 "github.com/TheJumpCloud/jcapi-go/v1"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const (
|
||||
contentType = "application/json"
|
||||
accept = "application/json"
|
||||
jumpCloudDefaultApiUrl = "https://console.jumpcloud.com"
|
||||
jumpCloudSearchPageSize = 100
|
||||
)
|
||||
|
||||
// jumpCloudUser represents a JumpCloud V1 API system user.
|
||||
type jumpCloudUser struct {
|
||||
ID string `json:"_id"`
|
||||
Email string `json:"email"`
|
||||
Firstname string `json:"firstname"`
|
||||
Middlename string `json:"middlename"`
|
||||
Lastname string `json:"lastname"`
|
||||
}
|
||||
|
||||
// jumpCloudUserList represents the response from the JumpCloud search endpoint.
|
||||
type jumpCloudUserList struct {
|
||||
Results []jumpCloudUser `json:"results"`
|
||||
TotalCount int `json:"totalCount"`
|
||||
}
|
||||
|
||||
// JumpCloudManager JumpCloud manager client instance.
|
||||
type JumpCloudManager struct {
|
||||
client *v1.APIClient
|
||||
apiBase string
|
||||
apiToken string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
@@ -29,6 +45,7 @@ type JumpCloudManager struct {
|
||||
// JumpCloudClientConfig JumpCloud manager client configurations.
|
||||
type JumpCloudClientConfig struct {
|
||||
APIToken string
|
||||
ApiUrl string
|
||||
}
|
||||
|
||||
// JumpCloudCredentials JumpCloud authentication information.
|
||||
@@ -55,7 +72,15 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM
|
||||
return nil, fmt.Errorf("jumpCloud IdP configuration is incomplete, ApiToken is missing")
|
||||
}
|
||||
|
||||
client := v1.NewAPIClient(v1.NewConfiguration())
|
||||
apiBase := config.ApiUrl
|
||||
if apiBase == "" {
|
||||
apiBase = jumpCloudDefaultApiUrl
|
||||
}
|
||||
apiBase = strings.TrimSuffix(apiBase, "/")
|
||||
if !strings.HasSuffix(apiBase, "/api") {
|
||||
apiBase += "/api"
|
||||
}
|
||||
|
||||
credentials := &JumpCloudCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
@@ -64,7 +89,7 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM
|
||||
}
|
||||
|
||||
return &JumpCloudManager{
|
||||
client: client,
|
||||
apiBase: apiBase,
|
||||
apiToken: config.APIToken,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
@@ -78,37 +103,58 @@ func (jc *JumpCloudCredentials) Authenticate(_ context.Context) (JWTToken, error
|
||||
return JWTToken{}, nil
|
||||
}
|
||||
|
||||
func (jm *JumpCloudManager) authenticationContext() context.Context {
|
||||
return context.WithValue(context.Background(), v1.ContextAPIKey, v1.APIKey{
|
||||
Key: jm.apiToken,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from JumpCloud via ID.
|
||||
func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
authCtx := jm.authenticationContext()
|
||||
user, resp, err := jm.client.SystemusersApi.SystemusersGet(authCtx, userID, contentType, accept, nil)
|
||||
// doRequest executes an HTTP request against the JumpCloud V1 API.
|
||||
func (jm *JumpCloudManager) doRequest(ctx context.Context, method, path string, body io.Reader) ([]byte, error) {
|
||||
reqURL := jm.apiBase + path
|
||||
req, err := http.NewRequestWithContext(ctx, method, reqURL, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("x-api-key", jm.apiToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := jm.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
|
||||
return nil, fmt.Errorf("JumpCloud API request %s %s failed with status %d", method, path, resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from JumpCloud via ID.
|
||||
func (jm *JumpCloudManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
body, err := jm.doRequest(ctx, http.MethodGet, "/systemusers/"+userID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
var user jumpCloudUser
|
||||
if err = jm.helper.Unmarshal(body, &user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userData := parseJumpCloudUser(user)
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
@@ -116,30 +162,20 @@ func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, ap
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
|
||||
authCtx := jm.authenticationContext()
|
||||
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
|
||||
func (jm *JumpCloudManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
allUsers, err := jm.searchAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get account %s users, statusCode %d", accountID, resp.StatusCode)
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, user := range userList.Results {
|
||||
users := make([]*UserData, 0, len(allUsers))
|
||||
for _, user := range allUsers {
|
||||
userData := parseJumpCloudUser(user)
|
||||
userData.AppMetadata.WTAccountID = accountID
|
||||
|
||||
users = append(users, userData)
|
||||
}
|
||||
|
||||
@@ -148,27 +184,18 @@ func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
|
||||
authCtx := jm.authenticationContext()
|
||||
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
|
||||
func (jm *JumpCloudManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
allUsers, err := jm.searchAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
for _, user := range userList.Results {
|
||||
for _, user := range allUsers {
|
||||
userData := parseJumpCloudUser(user)
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
||||
}
|
||||
@@ -176,6 +203,41 @@ func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*Use
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// searchAllUsers paginates through all system users using limit/skip.
|
||||
func (jm *JumpCloudManager) searchAllUsers(ctx context.Context) ([]jumpCloudUser, error) {
|
||||
var allUsers []jumpCloudUser
|
||||
|
||||
for skip := 0; ; skip += jumpCloudSearchPageSize {
|
||||
searchReq := map[string]int{
|
||||
"limit": jumpCloudSearchPageSize,
|
||||
"skip": skip,
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(searchReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := jm.doRequest(ctx, http.MethodPost, "/search/systemusers", bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var userList jumpCloudUserList
|
||||
if err = jm.helper.Unmarshal(body, &userList); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allUsers = append(allUsers, userList.Results...)
|
||||
|
||||
if skip+len(userList.Results) >= userList.TotalCount {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return allUsers, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in JumpCloud Idp and sends an invitation.
|
||||
func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
@@ -183,7 +245,7 @@ func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*U
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
|
||||
func (jm *JumpCloudManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
searchFilter := map[string]interface{}{
|
||||
"searchFilter": map[string]interface{}{
|
||||
"filter": []string{email},
|
||||
@@ -191,25 +253,26 @@ func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*
|
||||
},
|
||||
}
|
||||
|
||||
authCtx := jm.authenticationContext()
|
||||
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, searchFilter)
|
||||
payload, err := json.Marshal(searchFilter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode)
|
||||
body, err := jm.doRequest(ctx, http.MethodPost, "/search/systemusers", bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
usersData := make([]*UserData, 0)
|
||||
var userList jumpCloudUserList
|
||||
if err = jm.helper.Unmarshal(body, &userList); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
usersData := make([]*UserData, 0, len(userList.Results))
|
||||
for _, user := range userList.Results {
|
||||
usersData = append(usersData, parseJumpCloudUser(user))
|
||||
}
|
||||
@@ -224,20 +287,11 @@ func (jm *JumpCloudManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
}
|
||||
|
||||
// DeleteUser from jumpCloud directory
|
||||
func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error {
|
||||
authCtx := jm.authenticationContext()
|
||||
_, resp, err := jm.client.SystemusersApi.SystemusersDelete(authCtx, userID, contentType, accept, nil)
|
||||
func (jm *JumpCloudManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
_, err := jm.doRequest(ctx, http.MethodDelete, "/systemusers/"+userID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
@@ -247,11 +301,11 @@ func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error {
|
||||
}
|
||||
|
||||
// parseJumpCloudUser parse JumpCloud system user returned from API V1 to UserData.
|
||||
func parseJumpCloudUser(user v1.Systemuserreturn) *UserData {
|
||||
func parseJumpCloudUser(user jumpCloudUser) *UserData {
|
||||
names := []string{user.Firstname, user.Middlename, user.Lastname}
|
||||
return &UserData{
|
||||
Email: user.Email,
|
||||
Name: strings.Join(names, " "),
|
||||
ID: user.Id,
|
||||
ID: user.ID,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -44,3 +51,212 @@ func TestNewJumpCloudManager(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJumpCloudGetUserDataByID(t *testing.T) {
|
||||
userResponse := jumpCloudUser{
|
||||
ID: "user123",
|
||||
Email: "test@example.com",
|
||||
Firstname: "John",
|
||||
Middlename: "",
|
||||
Lastname: "Doe",
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/systemusers/user123", r.URL.Path)
|
||||
assert.Equal(t, http.MethodGet, r.Method)
|
||||
assert.Equal(t, "test-api-key", r.Header.Get("x-api-key"))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(userResponse)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
manager := newTestJumpCloudManager(t, server.URL)
|
||||
|
||||
userData, err := manager.GetUserDataByID(context.Background(), "user123", AppMetadata{WTAccountID: "acc1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "user123", userData.ID)
|
||||
assert.Equal(t, "test@example.com", userData.Email)
|
||||
assert.Equal(t, "John Doe", userData.Name)
|
||||
assert.Equal(t, "acc1", userData.AppMetadata.WTAccountID)
|
||||
}
|
||||
|
||||
func TestJumpCloudGetAccount(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/search/systemusers", r.URL.Path)
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
|
||||
var reqBody map[string]any
|
||||
assert.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody))
|
||||
assert.Contains(t, reqBody, "limit")
|
||||
assert.Contains(t, reqBody, "skip")
|
||||
|
||||
resp := jumpCloudUserList{
|
||||
Results: []jumpCloudUser{
|
||||
{ID: "u1", Email: "a@test.com", Firstname: "Alice", Lastname: "Smith"},
|
||||
{ID: "u2", Email: "b@test.com", Firstname: "Bob", Lastname: "Jones"},
|
||||
},
|
||||
TotalCount: 2,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
manager := newTestJumpCloudManager(t, server.URL)
|
||||
|
||||
users, err := manager.GetAccount(context.Background(), "testAccount")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, users, 2)
|
||||
assert.Equal(t, "testAccount", users[0].AppMetadata.WTAccountID)
|
||||
assert.Equal(t, "testAccount", users[1].AppMetadata.WTAccountID)
|
||||
}
|
||||
|
||||
func TestJumpCloudGetAllAccounts(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := jumpCloudUserList{
|
||||
Results: []jumpCloudUser{
|
||||
{ID: "u1", Email: "a@test.com", Firstname: "Alice"},
|
||||
{ID: "u2", Email: "b@test.com", Firstname: "Bob"},
|
||||
},
|
||||
TotalCount: 2,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
manager := newTestJumpCloudManager(t, server.URL)
|
||||
|
||||
indexedUsers, err := manager.GetAllAccounts(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, indexedUsers[UnsetAccountID], 2)
|
||||
}
|
||||
|
||||
func TestJumpCloudGetAllAccountsPagination(t *testing.T) {
|
||||
totalUsers := 250
|
||||
allUsers := make([]jumpCloudUser, totalUsers)
|
||||
for i := range allUsers {
|
||||
allUsers[i] = jumpCloudUser{
|
||||
ID: fmt.Sprintf("u%d", i),
|
||||
Email: fmt.Sprintf("user%d@test.com", i),
|
||||
Firstname: fmt.Sprintf("User%d", i),
|
||||
}
|
||||
}
|
||||
|
||||
requestCount := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var reqBody map[string]int
|
||||
assert.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody))
|
||||
|
||||
limit := reqBody["limit"]
|
||||
skip := reqBody["skip"]
|
||||
requestCount++
|
||||
|
||||
end := skip + limit
|
||||
if end > totalUsers {
|
||||
end = totalUsers
|
||||
}
|
||||
|
||||
resp := jumpCloudUserList{
|
||||
Results: allUsers[skip:end],
|
||||
TotalCount: totalUsers,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
manager := newTestJumpCloudManager(t, server.URL)
|
||||
|
||||
indexedUsers, err := manager.GetAllAccounts(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, indexedUsers[UnsetAccountID], totalUsers)
|
||||
assert.Equal(t, 3, requestCount, "should require 3 pages for 250 users at page size 100")
|
||||
}
|
||||
|
||||
func TestJumpCloudGetUserByEmail(t *testing.T) {
|
||||
searchResponse := jumpCloudUserList{
|
||||
Results: []jumpCloudUser{
|
||||
{ID: "u1", Email: "alice@test.com", Firstname: "Alice", Lastname: "Smith"},
|
||||
},
|
||||
TotalCount: 1,
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/search/systemusers", r.URL.Path)
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(body), "alice@test.com")
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(searchResponse)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
manager := newTestJumpCloudManager(t, server.URL)
|
||||
|
||||
users, err := manager.GetUserByEmail(context.Background(), "alice@test.com")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, users, 1)
|
||||
assert.Equal(t, "alice@test.com", users[0].Email)
|
||||
}
|
||||
|
||||
func TestJumpCloudDeleteUser(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/systemusers/user123", r.URL.Path)
|
||||
assert.Equal(t, http.MethodDelete, r.Method)
|
||||
assert.Equal(t, "test-api-key", r.Header.Get("x-api-key"))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"_id": "user123"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
manager := newTestJumpCloudManager(t, server.URL)
|
||||
|
||||
err := manager.DeleteUser(context.Background(), "user123")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestJumpCloudAPIError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
manager := newTestJumpCloudManager(t, server.URL)
|
||||
|
||||
_, err := manager.GetUserDataByID(context.Background(), "user123", AppMetadata{})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "401")
|
||||
}
|
||||
|
||||
func TestParseJumpCloudUser(t *testing.T) {
|
||||
user := jumpCloudUser{
|
||||
ID: "abc123",
|
||||
Email: "test@example.com",
|
||||
Firstname: "John",
|
||||
Middlename: "M",
|
||||
Lastname: "Doe",
|
||||
}
|
||||
|
||||
userData := parseJumpCloudUser(user)
|
||||
assert.Equal(t, "abc123", userData.ID)
|
||||
assert.Equal(t, "test@example.com", userData.Email)
|
||||
assert.Equal(t, "John M Doe", userData.Name)
|
||||
}
|
||||
|
||||
func newTestJumpCloudManager(t *testing.T, apiBase string) *JumpCloudManager {
|
||||
t.Helper()
|
||||
return &JumpCloudManager{
|
||||
apiBase: apiBase,
|
||||
apiToken: "test-api-key",
|
||||
httpClient: http.DefaultClient,
|
||||
helper: JsonParser{},
|
||||
appMetrics: nil,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -249,7 +249,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
if err != nil {
|
||||
newLabel = ""
|
||||
} else {
|
||||
_, err := transaction.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, update.Name)
|
||||
_, err := transaction.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, newLabel)
|
||||
if err == nil {
|
||||
newLabel = ""
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
@@ -2739,3 +2740,70 @@ func TestProcessPeerAddAuth(t *testing.T) {
|
||||
assert.Empty(t, config.GroupsToAdd)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdatePeer_DnsLabelCollisionWithFQDN(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
// Add first peer with hostname that produces DNS label "netbird1"
|
||||
key1, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key1.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "netbird1.netbird.cloud"},
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add first peer")
|
||||
assert.Equal(t, "netbird1", peer1.DNSLabel)
|
||||
|
||||
// Add second peer with a different hostname
|
||||
key2, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key2.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "ip-10-29-5-130"},
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
update := peer2.Copy()
|
||||
update.Name = "netbird1.demo.netbird.cloud"
|
||||
updated, err := manager.UpdatePeer(context.Background(), accountID, userID, update)
|
||||
require.NoError(t, err, "renaming peer should not fail with duplicate DNS label error")
|
||||
assert.Equal(t, "netbird1.demo.netbird.cloud", updated.Name)
|
||||
assert.NotEqual(t, "netbird1", updated.DNSLabel, "DNS label should not collide with existing peer")
|
||||
assert.Contains(t, updated.DNSLabel, "netbird1-", "DNS label should be IP-based fallback")
|
||||
}
|
||||
|
||||
func TestUpdatePeer_DnsLabelUniqueName(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key1, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key1.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "web-server"},
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "web-server", peer1.DNSLabel)
|
||||
|
||||
// Add second peer and rename it to a unique FQDN whose first label doesn't collide
|
||||
key2, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key2.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "old-name"},
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
update := peer2.Copy()
|
||||
update.Name = "api-server.example.com"
|
||||
updated, err := manager.UpdatePeer(context.Background(), accountID, userID, update)
|
||||
require.NoError(t, err, "renaming to unique FQDN should succeed")
|
||||
assert.Equal(t, "api-server", updated.DNSLabel, "DNS label should be first label of FQDN")
|
||||
}
|
||||
|
||||
@@ -932,3 +932,71 @@ func TestProtect_HeaderAuth_SubsequentRequestUsesSessionCookie(t *testing.T) {
|
||||
assert.Equal(t, "header-user", capturedData2.GetUserID())
|
||||
assert.Equal(t, "header", capturedData2.GetAuthMethod())
|
||||
}
|
||||
|
||||
// TestProtect_HeaderAuth_MultipleValuesSameHeader verifies that the proxy
|
||||
// correctly handles multiple valid credentials for the same header name.
|
||||
// In production, the mgmt gRPC authenticateHeader iterates all configured
|
||||
// header auths and accepts if any hash matches (OR semantics). The proxy
|
||||
// creates one Header scheme per entry, but a single gRPC call checks all.
|
||||
func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
// Mock simulates mgmt behavior: accepts either token-a or token-b.
|
||||
accepted := map[string]bool{"Bearer token-a": true, "Bearer token-b": true}
|
||||
mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
ha := req.GetHeaderAuth()
|
||||
if ha != nil && accepted[ha.GetHeaderValue()] {
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour)
|
||||
require.NoError(t, err)
|
||||
return &proto.AuthenticateResponse{Success: true, SessionToken: token}, nil
|
||||
}
|
||||
return &proto.AuthenticateResponse{Success: false}, nil
|
||||
}}
|
||||
|
||||
// Single Header scheme (as if one entry existed), but the mock checks both values.
|
||||
hdr := NewHeader(mock, "svc1", "acc1", "Authorization")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
|
||||
var backendCalled bool
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
backendCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
t.Run("first value accepted", func(t *testing.T) {
|
||||
backendCalled = false
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.Header.Set("Authorization", "Bearer token-a")
|
||||
req = req.WithContext(proxy.WithCapturedData(req.Context(), proxy.NewCapturedData("")))
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.True(t, backendCalled, "first token should be accepted")
|
||||
})
|
||||
|
||||
t.Run("second value accepted", func(t *testing.T) {
|
||||
backendCalled = false
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.Header.Set("Authorization", "Bearer token-b")
|
||||
req = req.WithContext(proxy.WithCapturedData(req.Context(), proxy.NewCapturedData("")))
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.True(t, backendCalled, "second token should be accepted")
|
||||
})
|
||||
|
||||
t.Run("unknown value rejected", func(t *testing.T) {
|
||||
backendCalled = false
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.Header.Set("Authorization", "Bearer token-c")
|
||||
req = req.WithContext(proxy.WithCapturedData(req.Context(), proxy.NewCapturedData("")))
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
assert.False(t, backendCalled, "unknown token should be rejected")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,13 +5,12 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
@@ -20,45 +19,55 @@ import (
|
||||
)
|
||||
|
||||
func Test_S3HandlerGetUploadURL(t *testing.T) {
|
||||
if runtime.GOOS != "linux" && os.Getenv("CI") == "true" {
|
||||
t.Skip("Skipping test on non-Linux and CI environment due to docker dependency")
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Skipping test on Windows due to potential docker dependency")
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("Skipping test on non-Linux due to docker dependency")
|
||||
}
|
||||
|
||||
awsEndpoint := "http://127.0.0.1:4566"
|
||||
awsRegion := "us-east-1"
|
||||
|
||||
ctx := context.Background()
|
||||
containerRequest := testcontainers.ContainerRequest{
|
||||
Image: "localstack/localstack:s3-latest",
|
||||
ExposedPorts: []string{"4566:4566/tcp"},
|
||||
WaitingFor: wait.ForLog("Ready"),
|
||||
}
|
||||
|
||||
c, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
|
||||
ContainerRequest: containerRequest,
|
||||
Started: true,
|
||||
ContainerRequest: testcontainers.ContainerRequest{
|
||||
Image: "minio/minio:RELEASE.2025-04-22T22-12-26Z",
|
||||
ExposedPorts: []string{"9000/tcp"},
|
||||
Env: map[string]string{
|
||||
"MINIO_ROOT_USER": "minioadmin",
|
||||
"MINIO_ROOT_PASSWORD": "minioadmin",
|
||||
},
|
||||
Cmd: []string{"server", "/data"},
|
||||
WaitingFor: wait.ForHTTP("/minio/health/ready").WithPort("9000"),
|
||||
},
|
||||
Started: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
defer func(c testcontainers.Container, ctx context.Context) {
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
if err := c.Terminate(ctx); err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}(c, ctx)
|
||||
})
|
||||
|
||||
mappedPort, err := c.MappedPort(ctx, "9000")
|
||||
require.NoError(t, err)
|
||||
|
||||
hostIP, err := c.Host(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
awsEndpoint := "http://" + hostIP + ":" + mappedPort.Port()
|
||||
|
||||
t.Setenv("AWS_REGION", awsRegion)
|
||||
t.Setenv("AWS_ENDPOINT_URL", awsEndpoint)
|
||||
t.Setenv("AWS_ACCESS_KEY_ID", "test")
|
||||
t.Setenv("AWS_SECRET_ACCESS_KEY", "test")
|
||||
t.Setenv("AWS_ACCESS_KEY_ID", "minioadmin")
|
||||
t.Setenv("AWS_SECRET_ACCESS_KEY", "minioadmin")
|
||||
t.Setenv("AWS_CONFIG_FILE", "")
|
||||
t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "")
|
||||
t.Setenv("AWS_PROFILE", "")
|
||||
|
||||
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(awsRegion), config.WithBaseEndpoint(awsEndpoint))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
cfg, err := config.LoadDefaultConfig(ctx,
|
||||
config.WithRegion(awsRegion),
|
||||
config.WithBaseEndpoint(awsEndpoint),
|
||||
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider("minioadmin", "minioadmin", "")),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := s3.NewFromConfig(cfg, func(o *s3.Options) {
|
||||
o.UsePathStyle = true
|
||||
@@ -66,19 +75,16 @@ func Test_S3HandlerGetUploadURL(t *testing.T) {
|
||||
})
|
||||
|
||||
bucketName := "test"
|
||||
if _, err := client.CreateBucket(ctx, &s3.CreateBucketInput{
|
||||
_, err = client.CreateBucket(ctx, &s3.CreateBucketInput{
|
||||
Bucket: &bucketName,
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
list, err := client.ListBuckets(ctx, &s3.ListBucketsInput{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, len(list.Buckets), 1)
|
||||
assert.Equal(t, *list.Buckets[0].Name, bucketName)
|
||||
require.Len(t, list.Buckets, 1)
|
||||
require.Equal(t, bucketName, *list.Buckets[0].Name)
|
||||
|
||||
t.Setenv(bucketVar, bucketName)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user