mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 00:36:38 +00:00
Compare commits
42 Commits
feature/di
...
trigger-pr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bfeb60fbb5 | ||
|
|
ea41cf2d2c | ||
|
|
0c990ab662 | ||
|
|
101c813e98 | ||
|
|
5333e55a81 | ||
|
|
81c11df103 | ||
|
|
f74bc48d16 | ||
|
|
0169e4540f | ||
|
|
cead3f38ee | ||
|
|
b55262d4a2 | ||
|
|
2248ff392f | ||
|
|
06966da012 | ||
|
|
d4f7df271a | ||
|
|
5299549eb6 | ||
|
|
7d791620a6 | ||
|
|
44ab454a13 | ||
|
|
11f50d6c38 | ||
|
|
05af39a69b | ||
|
|
074df56c3d | ||
|
|
2381e216e4 | ||
|
|
ded04b7627 | ||
|
|
67211010f7 | ||
|
|
c61568ceb4 | ||
|
|
737d6061bf | ||
|
|
ee3a67d2d8 | ||
|
|
1a32e4c223 | ||
|
|
269d5d1cba | ||
|
|
a1de2b8a98 | ||
|
|
d0221a3e72 | ||
|
|
8da23daae3 | ||
|
|
f86022eace | ||
|
|
ee54827f94 | ||
|
|
e908dea702 | ||
|
|
030650a905 | ||
|
|
e01998815e | ||
|
|
07e4a5a23c | ||
|
|
b3a2992a10 | ||
|
|
202fa47f2b | ||
|
|
4888021ba6 | ||
|
|
a0b0b664b6 | ||
|
|
50da5074e7 | ||
|
|
58daa674ef |
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
|||||||
- name: codespell
|
- name: codespell
|
||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
|
||||||
skip: go.mod,go.sum
|
skip: go.mod,go.sum
|
||||||
golangci:
|
golangci:
|
||||||
strategy:
|
strategy:
|
||||||
|
|||||||
@@ -3,15 +3,7 @@ package android
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
gstatus "google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/cmd"
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/auth"
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
@@ -84,34 +76,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||||
supportsSSO := true
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
if err != nil {
|
||||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
return false, fmt.Errorf("failed to create auth client: %v", err)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
}
|
||||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
defer authClient.Close()
|
||||||
s, ok := gstatus.FromError(err)
|
|
||||||
if !ok {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
|
|
||||||
supportsSSO = false
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
|
||||||
}
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to check SSO support: %v", err)
|
||||||
return err
|
}
|
||||||
})
|
|
||||||
|
|
||||||
if !supportsSSO {
|
if !supportsSSO {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
@@ -129,19 +108,17 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||||
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
//nolint
|
//nolint
|
||||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||||
|
err, _ = authClient.Login(ctxWithValues, setupKey, "")
|
||||||
err := a.withBackOff(a.ctx, func() error {
|
|
||||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
|
||||||
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
|
||||||
// we got an answer from management, exit backoff earlier
|
|
||||||
return backoff.Permanent(backoffErr)
|
|
||||||
}
|
|
||||||
return backoffErr
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||||
@@ -160,49 +137,41 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
||||||
var needsLogin bool
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
// check if we need to generate JWT token
|
// check if we need to generate JWT token
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
needsLogin, err := authClient.IsLoginRequired(a.ctx)
|
||||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
|
||||||
return
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("failed to check login requirement: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
if needsLogin {
|
if needsLogin {
|
||||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
|
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, isAndroidTV)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||||
}
|
}
|
||||||
jwtToken = tokenInfo.GetTokenToUse()
|
jwtToken = tokenInfo.GetTokenToUse()
|
||||||
}
|
}
|
||||||
|
|
||||||
err = a.withBackOff(a.ctx, func() error {
|
err, _ = authClient.Login(a.ctx, "", jwtToken)
|
||||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
go urlOpener.OnLoginSuccess()
|
|
||||||
}
|
|
||||||
|
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go urlOpener.OnLoginSuccess()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
|
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, isAndroidTV)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
||||||
@@ -212,22 +181,10 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*a
|
|||||||
|
|
||||||
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||||
|
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
tokenInfo, err := oAuthFlow.WaitToken(a.ctx, flowInfo)
|
||||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
|
||||||
defer cancel()
|
|
||||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &tokenInfo, nil
|
return &tokenInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
|
||||||
return backoff.RetryNotify(
|
|
||||||
bf,
|
|
||||||
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
|
|
||||||
func(err error, duration time.Duration) {
|
|
||||||
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/server"
|
"github.com/netbirdio/netbird/client/server"
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/upload-server/types"
|
"github.com/netbirdio/netbird/upload-server/types"
|
||||||
)
|
)
|
||||||
@@ -98,7 +97,6 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
|||||||
client := proto.NewDaemonServiceClient(conn)
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
request := &proto.DebugBundleRequest{
|
request := &proto.DebugBundleRequest{
|
||||||
Anonymize: anonymizeFlag,
|
Anonymize: anonymizeFlag,
|
||||||
Status: getStatusOutput(cmd, anonymizeFlag),
|
|
||||||
SystemInfo: systemInfoFlag,
|
SystemInfo: systemInfoFlag,
|
||||||
LogFileCount: logFileCount,
|
LogFileCount: logFileCount,
|
||||||
}
|
}
|
||||||
@@ -221,21 +219,37 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
time.Sleep(3 * time.Second)
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
cpuProfilingStarted := false
|
||||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
|
if _, err := client.StartCPUProfile(cmd.Context(), &proto.StartCPUProfileRequest{}); err != nil {
|
||||||
|
cmd.PrintErrf("Failed to start CPU profiling: %v\n", err)
|
||||||
|
} else {
|
||||||
|
cpuProfilingStarted = true
|
||||||
|
defer func() {
|
||||||
|
if cpuProfilingStarted {
|
||||||
|
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
|
||||||
|
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
||||||
return waitErr
|
return waitErr
|
||||||
}
|
}
|
||||||
cmd.Println("\nDuration completed")
|
cmd.Println("\nDuration completed")
|
||||||
|
|
||||||
|
if cpuProfilingStarted {
|
||||||
|
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
|
||||||
|
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
|
||||||
|
} else {
|
||||||
|
cpuProfilingStarted = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
cmd.Println("Creating debug bundle...")
|
cmd.Println("Creating debug bundle...")
|
||||||
|
|
||||||
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
|
||||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
|
||||||
request := &proto.DebugBundleRequest{
|
request := &proto.DebugBundleRequest{
|
||||||
Anonymize: anonymizeFlag,
|
Anonymize: anonymizeFlag,
|
||||||
Status: statusOutput,
|
|
||||||
SystemInfo: systemInfoFlag,
|
SystemInfo: systemInfoFlag,
|
||||||
LogFileCount: logFileCount,
|
LogFileCount: logFileCount,
|
||||||
}
|
}
|
||||||
@@ -302,24 +316,6 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
|
||||||
var statusOutputString string
|
|
||||||
statusResp, err := getStatus(cmd.Context(), true)
|
|
||||||
if err != nil {
|
|
||||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
|
||||||
} else {
|
|
||||||
pm := profilemanager.NewProfileManager()
|
|
||||||
var profName string
|
|
||||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
|
||||||
profName = activeProf.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName)
|
|
||||||
statusOutputString = overview.FullDetailSummary()
|
|
||||||
}
|
|
||||||
return statusOutputString
|
|
||||||
}
|
|
||||||
|
|
||||||
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
|
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
|
||||||
ticker := time.NewTicker(1 * time.Second)
|
ticker := time.NewTicker(1 * time.Second)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
@@ -378,7 +374,8 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
|
|||||||
InternalConfig: config,
|
InternalConfig: config,
|
||||||
StatusRecorder: recorder,
|
StatusRecorder: recorder,
|
||||||
SyncResponse: syncResponse,
|
SyncResponse: syncResponse,
|
||||||
LogFile: logFilePath,
|
LogPath: logFilePath,
|
||||||
|
CPUProfile: nil,
|
||||||
},
|
},
|
||||||
debug.BundleConfig{
|
debug.BundleConfig{
|
||||||
IncludeSystemInfo: true,
|
IncludeSystemInfo: true,
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"os/user"
|
"os/user"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
@@ -277,18 +276,19 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
|
|||||||
}
|
}
|
||||||
|
|
||||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
||||||
|
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
needsLogin := false
|
needsLogin := false
|
||||||
|
|
||||||
err := WithBackOff(func() error {
|
err, isAuthError := authClient.Login(ctx, "", "")
|
||||||
err := internal.Login(ctx, config, "", "")
|
if isAuthError {
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
needsLogin = true
|
||||||
needsLogin = true
|
} else if err != nil {
|
||||||
return nil
|
return fmt.Errorf("login check failed: %v", err)
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
@@ -300,23 +300,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
|||||||
jwtToken = tokenInfo.GetTokenToUse()
|
jwtToken = tokenInfo.GetTokenToUse()
|
||||||
}
|
}
|
||||||
|
|
||||||
var lastError error
|
err, _ = authClient.Login(ctx, setupKey, jwtToken)
|
||||||
|
|
||||||
err = WithBackOff(func() error {
|
|
||||||
err := internal.Login(ctx, config, setupKey, jwtToken)
|
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
|
||||||
lastError = err
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
|
|
||||||
if lastError != nil {
|
|
||||||
return fmt.Errorf("login failed: %v", lastError)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -344,11 +330,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
|||||||
|
|
||||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
||||||
|
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
|
||||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
|
||||||
defer c()
|
|
||||||
|
|
||||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
profName = activeProf.Name
|
profName = activeProf.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
|
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
|
||||||
var statusOutputString string
|
var statusOutputString string
|
||||||
switch {
|
switch {
|
||||||
case detailFlag:
|
case detailFlag:
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
"github.com/netbirdio/netbird/management/server/job"
|
||||||
|
|
||||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||||
client "github.com/netbirdio/netbird/client/server"
|
client "github.com/netbirdio/netbird/client/server"
|
||||||
@@ -97,6 +98,8 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
peersmanager := peers.NewManager(store, permissionsManagerMock)
|
peersmanager := peers.NewManager(store, permissionsManagerMock)
|
||||||
settingsManagerMock := settings.NewMockManager(ctrl)
|
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||||
|
|
||||||
|
jobManager := job.NewJobManager(nil, store, peersmanager)
|
||||||
|
|
||||||
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
|
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
@@ -115,7 +118,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
|
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
|
||||||
|
|
||||||
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -124,7 +127,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -200,7 +200,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
|||||||
connectClient := internal.NewConnectClient(ctx, config, r, false)
|
connectClient := internal.NewConnectClient(ctx, config, r, false)
|
||||||
SetupDebugHandler(ctx, config, r, connectClient, "")
|
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||||
|
|
||||||
return connectClient.Run(nil)
|
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
|
||||||
}
|
}
|
||||||
|
|
||||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {
|
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||||
@@ -68,6 +69,8 @@ type Options struct {
|
|||||||
StatePath string
|
StatePath string
|
||||||
// DisableClientRoutes disables the client routes
|
// DisableClientRoutes disables the client routes
|
||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
|
// BlockInbound blocks all inbound connections from peers
|
||||||
|
BlockInbound bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateCredentials checks that exactly one credential type is provided
|
// validateCredentials checks that exactly one credential type is provided
|
||||||
@@ -136,6 +139,7 @@ func New(opts Options) (*Client, error) {
|
|||||||
PreSharedKey: &opts.PreSharedKey,
|
PreSharedKey: &opts.PreSharedKey,
|
||||||
DisableServerRoutes: &t,
|
DisableServerRoutes: &t,
|
||||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||||
|
BlockInbound: &opts.BlockInbound,
|
||||||
}
|
}
|
||||||
if opts.ConfigPath != "" {
|
if opts.ConfigPath != "" {
|
||||||
config, err = profilemanager.UpdateOrCreateConfig(input)
|
config, err = profilemanager.UpdateOrCreateConfig(input)
|
||||||
@@ -176,7 +180,13 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
|
|
||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||||
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
|
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create auth client: %w", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
||||||
return fmt.Errorf("login: %w", err)
|
return fmt.Errorf("login: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,7 +200,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
run := make(chan struct{})
|
run := make(chan struct{})
|
||||||
clientErr := make(chan error, 1)
|
clientErr := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
if err := client.Run(run); err != nil {
|
if err := client.Run(run, ""); err != nil {
|
||||||
clientErr <- err
|
clientErr <- err
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|||||||
@@ -83,6 +83,10 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
return fmt.Errorf("acl manager init: %w", err)
|
return fmt.Errorf("acl manager init: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := m.initNoTrackChain(); err != nil {
|
||||||
|
return fmt.Errorf("init notrack chain: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// persist early to ensure cleanup of chains
|
// persist early to ensure cleanup of chains
|
||||||
go func() {
|
go func() {
|
||||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||||
@@ -177,6 +181,10 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := m.cleanupNoTrackChain(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
if err := m.aclMgr.Reset(); err != nil {
|
if err := m.aclMgr.Reset(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
||||||
}
|
}
|
||||||
@@ -277,6 +285,125 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
|||||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
chainNameRaw = "NETBIRD-RAW"
|
||||||
|
chainOUTPUT = "OUTPUT"
|
||||||
|
tableRaw = "raw"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||||
|
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
|
||||||
|
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
|
||||||
|
//
|
||||||
|
// Traffic flows that need NOTRACK:
|
||||||
|
//
|
||||||
|
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
|
||||||
|
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
|
||||||
|
// Matched by: sport=wgPort
|
||||||
|
//
|
||||||
|
// 2. Egress: Proxy -> WireGuard (via raw socket)
|
||||||
|
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
|
||||||
|
// Matched by: dport=wgPort
|
||||||
|
//
|
||||||
|
// 3. Ingress: Packets to WireGuard
|
||||||
|
// dst=127.0.0.1:wgPort
|
||||||
|
// Matched by: dport=wgPort
|
||||||
|
//
|
||||||
|
// 4. Ingress: Packets to proxy (after eBPF rewrite)
|
||||||
|
// dst=127.0.0.1:proxyPort
|
||||||
|
// Matched by: dport=proxyPort
|
||||||
|
//
|
||||||
|
// Rules are cleaned up when the firewall manager is closed.
|
||||||
|
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
wgPortStr := fmt.Sprintf("%d", wgPort)
|
||||||
|
proxyPortStr := fmt.Sprintf("%d", proxyPort)
|
||||||
|
|
||||||
|
// Egress rules: match outgoing loopback UDP packets
|
||||||
|
outputRuleSport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--sport", wgPortStr, "-j", "NOTRACK"}
|
||||||
|
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleSport...); err != nil {
|
||||||
|
return fmt.Errorf("add output sport notrack rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
outputRuleDport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
|
||||||
|
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleDport...); err != nil {
|
||||||
|
return fmt.Errorf("add output dport notrack rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ingress rules: match incoming loopback UDP packets
|
||||||
|
preroutingRuleWg := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
|
||||||
|
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleWg...); err != nil {
|
||||||
|
return fmt.Errorf("add prerouting wg notrack rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
preroutingRuleProxy := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", proxyPortStr, "-j", "NOTRACK"}
|
||||||
|
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleProxy...); err != nil {
|
||||||
|
return fmt.Errorf("add prerouting proxy notrack rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) initNoTrackChain() error {
|
||||||
|
if err := m.cleanupNoTrackChain(); err != nil {
|
||||||
|
log.Debugf("cleanup notrack chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.ipv4Client.NewChain(tableRaw, chainNameRaw); err != nil {
|
||||||
|
return fmt.Errorf("create chain: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jumpRule := []string{"-j", chainNameRaw}
|
||||||
|
|
||||||
|
if err := m.ipv4Client.InsertUnique(tableRaw, chainOUTPUT, 1, jumpRule...); err != nil {
|
||||||
|
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
||||||
|
log.Debugf("delete orphan chain: %v", delErr)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("add output jump rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.ipv4Client.InsertUnique(tableRaw, chainPREROUTING, 1, jumpRule...); err != nil {
|
||||||
|
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); delErr != nil {
|
||||||
|
log.Debugf("delete output jump rule: %v", delErr)
|
||||||
|
}
|
||||||
|
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
||||||
|
log.Debugf("delete orphan chain: %v", delErr)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("add prerouting jump rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) cleanupNoTrackChain() error {
|
||||||
|
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("check chain exists: %w", err)
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
jumpRule := []string{"-j", chainNameRaw}
|
||||||
|
|
||||||
|
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); err != nil {
|
||||||
|
return fmt.Errorf("remove output jump rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPREROUTING, jumpRule...); err != nil {
|
||||||
|
return fmt.Errorf("remove prerouting jump rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.ipv4Client.ClearAndDeleteChain(tableRaw, chainNameRaw); err != nil {
|
||||||
|
return fmt.Errorf("clear and delete chain: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func getConntrackEstablished() []string {
|
func getConntrackEstablished() []string {
|
||||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -168,6 +168,10 @@ type Manager interface {
|
|||||||
|
|
||||||
// RemoveInboundDNAT removes inbound DNAT rule
|
// RemoveInboundDNAT removes inbound DNAT rule
|
||||||
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||||
|
|
||||||
|
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
|
||||||
|
// This prevents conntrack from interfering with WireGuard proxy communication.
|
||||||
|
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenKey(format string, pair RouterPair) string {
|
func GenKey(format string, pair RouterPair) string {
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/google/nftables/binaryutil"
|
"github.com/google/nftables/binaryutil"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
@@ -48,8 +49,10 @@ type Manager struct {
|
|||||||
rConn *nftables.Conn
|
rConn *nftables.Conn
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
|
|
||||||
router *router
|
router *router
|
||||||
aclManager *AclManager
|
aclManager *AclManager
|
||||||
|
notrackOutputChain *nftables.Chain
|
||||||
|
notrackPreroutingChain *nftables.Chain
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create nftables firewall manager
|
// Create nftables firewall manager
|
||||||
@@ -91,6 +94,10 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
return fmt.Errorf("acl manager init: %w", err)
|
return fmt.Errorf("acl manager init: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := m.initNoTrackChains(workTable); err != nil {
|
||||||
|
return fmt.Errorf("init notrack chains: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
stateManager.RegisterState(&ShutdownState{})
|
stateManager.RegisterState(&ShutdownState{})
|
||||||
|
|
||||||
// We only need to record minimal interface state for potential recreation.
|
// We only need to record minimal interface state for potential recreation.
|
||||||
@@ -288,7 +295,15 @@ func (m *Manager) Flush() error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.aclManager.Flush()
|
if err := m.aclManager.Flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.refreshNoTrackChains(); err != nil {
|
||||||
|
log.Errorf("failed to refresh notrack chains: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddDNATRule adds a DNAT rule
|
// AddDNATRule adds a DNAT rule
|
||||||
@@ -331,6 +346,176 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
|||||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
chainNameRawOutput = "netbird-raw-out"
|
||||||
|
chainNameRawPrerouting = "netbird-raw-pre"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||||
|
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
|
||||||
|
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
|
||||||
|
//
|
||||||
|
// Traffic flows that need NOTRACK:
|
||||||
|
//
|
||||||
|
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
|
||||||
|
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
|
||||||
|
// Matched by: sport=wgPort
|
||||||
|
//
|
||||||
|
// 2. Egress: Proxy -> WireGuard (via raw socket)
|
||||||
|
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
|
||||||
|
// Matched by: dport=wgPort
|
||||||
|
//
|
||||||
|
// 3. Ingress: Packets to WireGuard
|
||||||
|
// dst=127.0.0.1:wgPort
|
||||||
|
// Matched by: dport=wgPort
|
||||||
|
//
|
||||||
|
// 4. Ingress: Packets to proxy (after eBPF rewrite)
|
||||||
|
// dst=127.0.0.1:proxyPort
|
||||||
|
// Matched by: dport=proxyPort
|
||||||
|
//
|
||||||
|
// Rules are cleaned up when the firewall manager is closed.
|
||||||
|
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if m.notrackOutputChain == nil || m.notrackPreroutingChain == nil {
|
||||||
|
return fmt.Errorf("notrack chains not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyPortBytes := binaryutil.BigEndian.PutUint16(proxyPort)
|
||||||
|
wgPortBytes := binaryutil.BigEndian.PutUint16(wgPort)
|
||||||
|
loopback := []byte{127, 0, 0, 1}
|
||||||
|
|
||||||
|
// Egress rules: match outgoing loopback UDP packets
|
||||||
|
m.rConn.AddRule(&nftables.Rule{
|
||||||
|
Table: m.notrackOutputChain.Table,
|
||||||
|
Chain: m.notrackOutputChain,
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||||
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 2},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // sport=wgPort
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Notrack{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
m.rConn.AddRule(&nftables.Rule{
|
||||||
|
Table: m.notrackOutputChain.Table,
|
||||||
|
Chain: m.notrackOutputChain,
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||||
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Notrack{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Ingress rules: match incoming loopback UDP packets
|
||||||
|
m.rConn.AddRule(&nftables.Rule{
|
||||||
|
Table: m.notrackPreroutingChain.Table,
|
||||||
|
Chain: m.notrackPreroutingChain,
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||||
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Notrack{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
m.rConn.AddRule(&nftables.Rule{
|
||||||
|
Table: m.notrackPreroutingChain.Table,
|
||||||
|
Chain: m.notrackPreroutingChain,
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||||
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: proxyPortBytes}, // dport=proxyPort
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Notrack{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush notrack rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) initNoTrackChains(table *nftables.Table) error {
|
||||||
|
m.notrackOutputChain = m.rConn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameRawOutput,
|
||||||
|
Table: table,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookOutput,
|
||||||
|
Priority: nftables.ChainPriorityRaw,
|
||||||
|
})
|
||||||
|
|
||||||
|
m.notrackPreroutingChain = m.rConn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameRawPrerouting,
|
||||||
|
Table: table,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
|
Priority: nftables.ChainPriorityRaw,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush chain creation: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) refreshNoTrackChains() error {
|
||||||
|
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list chains: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tableName := getTableName()
|
||||||
|
for _, c := range chains {
|
||||||
|
if c.Table.Name != tableName {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch c.Name {
|
||||||
|
case chainNameRawOutput:
|
||||||
|
m.notrackOutputChain = c
|
||||||
|
case chainNameRawPrerouting:
|
||||||
|
m.notrackPreroutingChain = c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -570,6 +570,14 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
|
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||||
|
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||||
|
if m.nativeFirewall == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort)
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateSet updates the rule destinations associated with the given set
|
// UpdateSet updates the rule destinations associated with the given set
|
||||||
// by merging the existing prefixes with the new ones, then deduplicating.
|
// by merging the existing prefixes with the new ones, then deduplicating.
|
||||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
|
|||||||
169
client/iface/bind/dual_stack_conn.go
Normal file
169
client/iface/bind/dual_stack_conn.go
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errNoIPv4Conn = errors.New("no IPv4 connection available")
|
||||||
|
errNoIPv6Conn = errors.New("no IPv6 connection available")
|
||||||
|
errInvalidAddr = errors.New("invalid address type")
|
||||||
|
)
|
||||||
|
|
||||||
|
// DualStackPacketConn wraps IPv4 and IPv6 UDP connections and routes writes
|
||||||
|
// to the appropriate connection based on the destination address.
|
||||||
|
// ReadFrom is not used in the hot path - ICEBind receives packets via
|
||||||
|
// BatchReader.ReadBatch() directly. This is only used by udpMux for sending.
|
||||||
|
type DualStackPacketConn struct {
|
||||||
|
ipv4Conn net.PacketConn
|
||||||
|
ipv6Conn net.PacketConn
|
||||||
|
|
||||||
|
readFromWarn sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDualStackPacketConn creates a new dual-stack packet connection.
|
||||||
|
func NewDualStackPacketConn(ipv4Conn, ipv6Conn net.PacketConn) *DualStackPacketConn {
|
||||||
|
return &DualStackPacketConn{
|
||||||
|
ipv4Conn: ipv4Conn,
|
||||||
|
ipv6Conn: ipv6Conn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadFrom reads from the available connection (preferring IPv4).
|
||||||
|
// NOTE: This method is NOT used in the data path. ICEBind receives packets via
|
||||||
|
// BatchReader.ReadBatch() directly for both IPv4 and IPv6, which is much more efficient.
|
||||||
|
// This implementation exists only to satisfy the net.PacketConn interface for the udpMux,
|
||||||
|
// but the udpMux only uses WriteTo() for sending STUN responses - it never calls ReadFrom()
|
||||||
|
// because STUN packets are filtered and forwarded via HandleSTUNMessage() from the receive path.
|
||||||
|
func (d *DualStackPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||||
|
d.readFromWarn.Do(func() {
|
||||||
|
log.Warn("DualStackPacketConn.ReadFrom called - this is unexpected and may indicate an inefficient code path")
|
||||||
|
})
|
||||||
|
|
||||||
|
if d.ipv4Conn != nil {
|
||||||
|
return d.ipv4Conn.ReadFrom(b)
|
||||||
|
}
|
||||||
|
if d.ipv6Conn != nil {
|
||||||
|
return d.ipv6Conn.ReadFrom(b)
|
||||||
|
}
|
||||||
|
return 0, nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteTo writes to the appropriate connection based on the address type.
|
||||||
|
func (d *DualStackPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||||
|
udpAddr, ok := addr.(*net.UDPAddr)
|
||||||
|
if !ok {
|
||||||
|
return 0, &net.OpError{
|
||||||
|
Op: "write",
|
||||||
|
Net: "udp",
|
||||||
|
Addr: addr,
|
||||||
|
Err: errInvalidAddr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if udpAddr.IP.To4() == nil {
|
||||||
|
if d.ipv6Conn != nil {
|
||||||
|
return d.ipv6Conn.WriteTo(b, addr)
|
||||||
|
}
|
||||||
|
return 0, &net.OpError{
|
||||||
|
Op: "write",
|
||||||
|
Net: "udp6",
|
||||||
|
Addr: addr,
|
||||||
|
Err: errNoIPv6Conn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if d.ipv4Conn != nil {
|
||||||
|
return d.ipv4Conn.WriteTo(b, addr)
|
||||||
|
}
|
||||||
|
return 0, &net.OpError{
|
||||||
|
Op: "write",
|
||||||
|
Net: "udp4",
|
||||||
|
Addr: addr,
|
||||||
|
Err: errNoIPv4Conn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes both connections.
|
||||||
|
func (d *DualStackPacketConn) Close() error {
|
||||||
|
var result *multierror.Error
|
||||||
|
if d.ipv4Conn != nil {
|
||||||
|
if err := d.ipv4Conn.Close(); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if d.ipv6Conn != nil {
|
||||||
|
if err := d.ipv6Conn.Close(); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr returns the local address of the IPv4 connection if available,
|
||||||
|
// otherwise the IPv6 connection.
|
||||||
|
func (d *DualStackPacketConn) LocalAddr() net.Addr {
|
||||||
|
if d.ipv4Conn != nil {
|
||||||
|
return d.ipv4Conn.LocalAddr()
|
||||||
|
}
|
||||||
|
if d.ipv6Conn != nil {
|
||||||
|
return d.ipv6Conn.LocalAddr()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeadline sets the deadline for both connections.
|
||||||
|
func (d *DualStackPacketConn) SetDeadline(t time.Time) error {
|
||||||
|
var result *multierror.Error
|
||||||
|
if d.ipv4Conn != nil {
|
||||||
|
if err := d.ipv4Conn.SetDeadline(t); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if d.ipv6Conn != nil {
|
||||||
|
if err := d.ipv6Conn.SetDeadline(t); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadDeadline sets the read deadline for both connections.
|
||||||
|
func (d *DualStackPacketConn) SetReadDeadline(t time.Time) error {
|
||||||
|
var result *multierror.Error
|
||||||
|
if d.ipv4Conn != nil {
|
||||||
|
if err := d.ipv4Conn.SetReadDeadline(t); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if d.ipv6Conn != nil {
|
||||||
|
if err := d.ipv6Conn.SetReadDeadline(t); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetWriteDeadline sets the write deadline for both connections.
|
||||||
|
func (d *DualStackPacketConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
var result *multierror.Error
|
||||||
|
if d.ipv4Conn != nil {
|
||||||
|
if err := d.ipv4Conn.SetWriteDeadline(t); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if d.ipv6Conn != nil {
|
||||||
|
if err := d.ipv6Conn.SetWriteDeadline(t); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(result)
|
||||||
|
}
|
||||||
119
client/iface/bind/dual_stack_conn_bench_test.go
Normal file
119
client/iface/bind/dual_stack_conn_bench_test.go
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ipv4Addr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345}
|
||||||
|
ipv6Addr = &net.UDPAddr{IP: net.ParseIP("::1"), Port: 12345}
|
||||||
|
payload = make([]byte, 1200)
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkWriteTo_DirectUDPConn(b *testing.B) {
|
||||||
|
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = conn.WriteTo(payload, ipv4Addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWriteTo_DualStack_IPv4Only(b *testing.B) {
|
||||||
|
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
ds := NewDualStackPacketConn(conn, nil)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = ds.WriteTo(payload, ipv4Addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWriteTo_DualStack_IPv6Only(b *testing.B) {
|
||||||
|
conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
b.Skipf("IPv6 not available: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
ds := NewDualStackPacketConn(nil, conn)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = ds.WriteTo(payload, ipv6Addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWriteTo_DualStack_Both_IPv4Traffic(b *testing.B) {
|
||||||
|
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn4.Close()
|
||||||
|
|
||||||
|
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
b.Skipf("IPv6 not available: %v", err)
|
||||||
|
}
|
||||||
|
defer conn6.Close()
|
||||||
|
|
||||||
|
ds := NewDualStackPacketConn(conn4, conn6)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = ds.WriteTo(payload, ipv4Addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWriteTo_DualStack_Both_IPv6Traffic(b *testing.B) {
|
||||||
|
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn4.Close()
|
||||||
|
|
||||||
|
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
b.Skipf("IPv6 not available: %v", err)
|
||||||
|
}
|
||||||
|
defer conn6.Close()
|
||||||
|
|
||||||
|
ds := NewDualStackPacketConn(conn4, conn6)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = ds.WriteTo(payload, ipv6Addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWriteTo_DualStack_Both_MixedTraffic(b *testing.B) {
|
||||||
|
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn4.Close()
|
||||||
|
|
||||||
|
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
b.Skipf("IPv6 not available: %v", err)
|
||||||
|
}
|
||||||
|
defer conn6.Close()
|
||||||
|
|
||||||
|
ds := NewDualStackPacketConn(conn4, conn6)
|
||||||
|
addrs := []net.Addr{ipv4Addr, ipv6Addr}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = ds.WriteTo(payload, addrs[i&1])
|
||||||
|
}
|
||||||
|
}
|
||||||
191
client/iface/bind/dual_stack_conn_test.go
Normal file
191
client/iface/bind/dual_stack_conn_test.go
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDualStackPacketConn_RoutesWritesToCorrectSocket(t *testing.T) {
|
||||||
|
ipv4Conn := &mockPacketConn{network: "udp4"}
|
||||||
|
ipv6Conn := &mockPacketConn{network: "udp6"}
|
||||||
|
dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
addr *net.UDPAddr
|
||||||
|
wantSocket string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "IPv4 address",
|
||||||
|
addr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234},
|
||||||
|
wantSocket: "udp4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 address",
|
||||||
|
addr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234},
|
||||||
|
wantSocket: "udp6",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4-mapped IPv6 goes to IPv4",
|
||||||
|
addr: &net.UDPAddr{IP: net.ParseIP("::ffff:192.168.1.1"), Port: 1234},
|
||||||
|
wantSocket: "udp4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 loopback",
|
||||||
|
addr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234},
|
||||||
|
wantSocket: "udp4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 loopback",
|
||||||
|
addr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 1234},
|
||||||
|
wantSocket: "udp6",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ipv4Conn.writeCount = 0
|
||||||
|
ipv6Conn.writeCount = 0
|
||||||
|
|
||||||
|
n, err := dualStack.WriteTo([]byte("test"), tt.addr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 4, n)
|
||||||
|
|
||||||
|
if tt.wantSocket == "udp4" {
|
||||||
|
assert.Equal(t, 1, ipv4Conn.writeCount, "expected write to IPv4")
|
||||||
|
assert.Equal(t, 0, ipv6Conn.writeCount, "expected no write to IPv6")
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, 0, ipv4Conn.writeCount, "expected no write to IPv4")
|
||||||
|
assert.Equal(t, 1, ipv6Conn.writeCount, "expected write to IPv6")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDualStackPacketConn_IPv4OnlyRejectsIPv6(t *testing.T) {
|
||||||
|
dualStack := NewDualStackPacketConn(&mockPacketConn{network: "udp4"}, nil)
|
||||||
|
|
||||||
|
// IPv4 works
|
||||||
|
_, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// IPv6 fails
|
||||||
|
_, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234})
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no IPv6 connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDualStackPacketConn_IPv6OnlyRejectsIPv4(t *testing.T) {
|
||||||
|
dualStack := NewDualStackPacketConn(nil, &mockPacketConn{network: "udp6"})
|
||||||
|
|
||||||
|
// IPv6 works
|
||||||
|
_, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// IPv4 fails
|
||||||
|
_, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234})
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no IPv4 connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDualStackPacketConn_ReadFromIsNotUsedInHotPath documents that ReadFrom
|
||||||
|
// only reads from one socket (IPv4 preferred). This is fine because the actual
|
||||||
|
// receive path uses wireguard-go's BatchReader directly, not ReadFrom.
|
||||||
|
func TestDualStackPacketConn_ReadFromIsNotUsedInHotPath(t *testing.T) {
|
||||||
|
ipv4Conn := &mockPacketConn{
|
||||||
|
network: "udp4",
|
||||||
|
readData: []byte("from ipv4"),
|
||||||
|
readAddr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234},
|
||||||
|
}
|
||||||
|
ipv6Conn := &mockPacketConn{
|
||||||
|
network: "udp6",
|
||||||
|
readData: []byte("from ipv6"),
|
||||||
|
readAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234},
|
||||||
|
}
|
||||||
|
|
||||||
|
dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn)
|
||||||
|
|
||||||
|
buf := make([]byte, 100)
|
||||||
|
n, addr, err := dualStack.ReadFrom(buf)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
// reads from IPv4 (preferred) - this is expected behavior
|
||||||
|
assert.Equal(t, "from ipv4", string(buf[:n]))
|
||||||
|
assert.Equal(t, "192.168.1.1", addr.(*net.UDPAddr).IP.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDualStackPacketConn_LocalAddrPrefersIPv4(t *testing.T) {
|
||||||
|
ipv4Addr := &net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: 51820}
|
||||||
|
ipv6Addr := &net.UDPAddr{IP: net.ParseIP("::"), Port: 51820}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ipv4 net.PacketConn
|
||||||
|
ipv6 net.PacketConn
|
||||||
|
wantAddr net.Addr
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "both available returns IPv4",
|
||||||
|
ipv4: &mockPacketConn{localAddr: ipv4Addr},
|
||||||
|
ipv6: &mockPacketConn{localAddr: ipv6Addr},
|
||||||
|
wantAddr: ipv4Addr,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 only",
|
||||||
|
ipv4: &mockPacketConn{localAddr: ipv4Addr},
|
||||||
|
ipv6: nil,
|
||||||
|
wantAddr: ipv4Addr,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 only",
|
||||||
|
ipv4: nil,
|
||||||
|
ipv6: &mockPacketConn{localAddr: ipv6Addr},
|
||||||
|
wantAddr: ipv6Addr,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "neither returns nil",
|
||||||
|
ipv4: nil,
|
||||||
|
ipv6: nil,
|
||||||
|
wantAddr: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
dualStack := NewDualStackPacketConn(tt.ipv4, tt.ipv6)
|
||||||
|
assert.Equal(t, tt.wantAddr, dualStack.LocalAddr())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mock
|
||||||
|
|
||||||
|
type mockPacketConn struct {
|
||||||
|
network string
|
||||||
|
writeCount int
|
||||||
|
readData []byte
|
||||||
|
readAddr net.Addr
|
||||||
|
localAddr net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||||
|
if m.readData != nil {
|
||||||
|
return copy(b, m.readData), m.readAddr, nil
|
||||||
|
}
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||||
|
m.writeCount++
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockPacketConn) Close() error { return nil }
|
||||||
|
func (m *mockPacketConn) LocalAddr() net.Addr { return m.localAddr }
|
||||||
|
func (m *mockPacketConn) SetDeadline(t time.Time) error { return nil }
|
||||||
|
func (m *mockPacketConn) SetReadDeadline(t time.Time) error { return nil }
|
||||||
|
func (m *mockPacketConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
"github.com/pion/stun/v3"
|
"github.com/pion/stun/v3"
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/ipv4"
|
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
@@ -28,22 +27,7 @@ type receiverCreator struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||||
if ipv4PC, ok := pc.(*ipv4.PacketConn); ok {
|
return rc.iceBind.createReceiverFn(pc, conn, rxOffload, msgPool)
|
||||||
return rc.iceBind.createIPv4ReceiverFn(ipv4PC, conn, rxOffload, msgPool)
|
|
||||||
}
|
|
||||||
// IPv6 is currently not supported in the udpmux, this is a stub for compatibility with the
|
|
||||||
// wireguard-go ReceiverCreator interface which is called for both IPv4 and IPv6.
|
|
||||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
|
||||||
buf := bufs[0]
|
|
||||||
size, ep, err := conn.ReadFromUDPAddrPort(buf)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
sizes[0] = size
|
|
||||||
stdEp := &wgConn.StdNetEndpoint{AddrPort: ep}
|
|
||||||
eps[0] = stdEp
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ICEBind is a bind implementation with two main features:
|
// ICEBind is a bind implementation with two main features:
|
||||||
@@ -73,6 +57,8 @@ type ICEBind struct {
|
|||||||
|
|
||||||
muUDPMux sync.Mutex
|
muUDPMux sync.Mutex
|
||||||
udpMux *udpmux.UniversalUDPMuxDefault
|
udpMux *udpmux.UniversalUDPMuxDefault
|
||||||
|
ipv4Conn *net.UDPConn
|
||||||
|
ipv6Conn *net.UDPConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||||
@@ -118,6 +104,12 @@ func (s *ICEBind) Close() error {
|
|||||||
|
|
||||||
close(s.closedChan)
|
close(s.closedChan)
|
||||||
|
|
||||||
|
s.muUDPMux.Lock()
|
||||||
|
s.ipv4Conn = nil
|
||||||
|
s.ipv6Conn = nil
|
||||||
|
s.udpMux = nil
|
||||||
|
s.muUDPMux.Unlock()
|
||||||
|
|
||||||
return s.StdNetBind.Close()
|
return s.StdNetBind.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -175,19 +167,18 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
|
func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
|
||||||
s.muUDPMux.Lock()
|
s.muUDPMux.Lock()
|
||||||
defer s.muUDPMux.Unlock()
|
defer s.muUDPMux.Unlock()
|
||||||
|
|
||||||
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
|
// Detect IPv4 vs IPv6 from connection's local address
|
||||||
udpmux.UniversalUDPMuxParams{
|
if localAddr := conn.LocalAddr().(*net.UDPAddr); localAddr.IP.To4() != nil {
|
||||||
UDPConn: nbnet.WrapPacketConn(conn),
|
s.ipv4Conn = conn
|
||||||
Net: s.transportNet,
|
} else {
|
||||||
FilterFn: s.filterFn,
|
s.ipv6Conn = conn
|
||||||
WGAddress: s.address,
|
}
|
||||||
MTU: s.mtu,
|
s.createOrUpdateMux()
|
||||||
},
|
|
||||||
)
|
|
||||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||||
msgs := getMessages(msgsPool)
|
msgs := getMessages(msgsPool)
|
||||||
for i := range bufs {
|
for i := range bufs {
|
||||||
@@ -195,12 +186,13 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
|||||||
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
||||||
}
|
}
|
||||||
defer putMessages(msgs, msgsPool)
|
defer putMessages(msgs, msgsPool)
|
||||||
|
|
||||||
var numMsgs int
|
var numMsgs int
|
||||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
if rxOffload {
|
if rxOffload {
|
||||||
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
|
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
|
||||||
//nolint
|
//nolint:staticcheck
|
||||||
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
|
_, err = pc.ReadBatch((*msgs)[readAt:], 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@@ -222,12 +214,12 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
|||||||
}
|
}
|
||||||
numMsgs = 1
|
numMsgs = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < numMsgs; i++ {
|
for i := 0; i < numMsgs; i++ {
|
||||||
msg := &(*msgs)[i]
|
msg := &(*msgs)[i]
|
||||||
|
|
||||||
// todo: handle err
|
// todo: handle err
|
||||||
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
|
if ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr); ok {
|
||||||
if ok {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
sizes[i] = msg.N
|
sizes[i] = msg.N
|
||||||
@@ -248,6 +240,38 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// createOrUpdateMux creates or updates the UDP mux with the available connections.
|
||||||
|
// Must be called with muUDPMux held.
|
||||||
|
func (s *ICEBind) createOrUpdateMux() {
|
||||||
|
var muxConn net.PacketConn
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case s.ipv4Conn != nil && s.ipv6Conn != nil:
|
||||||
|
muxConn = NewDualStackPacketConn(
|
||||||
|
nbnet.WrapPacketConn(s.ipv4Conn),
|
||||||
|
nbnet.WrapPacketConn(s.ipv6Conn),
|
||||||
|
)
|
||||||
|
case s.ipv4Conn != nil:
|
||||||
|
muxConn = nbnet.WrapPacketConn(s.ipv4Conn)
|
||||||
|
case s.ipv6Conn != nil:
|
||||||
|
muxConn = nbnet.WrapPacketConn(s.ipv6Conn)
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't close the old mux - it doesn't own the underlying connections.
|
||||||
|
// The sockets are managed by WireGuard's StdNetBind, not by us.
|
||||||
|
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
|
||||||
|
udpmux.UniversalUDPMuxParams{
|
||||||
|
UDPConn: muxConn,
|
||||||
|
Net: s.transportNet,
|
||||||
|
FilterFn: s.filterFn,
|
||||||
|
WGAddress: s.address,
|
||||||
|
MTU: s.mtu,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
|
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
|
||||||
for i := range buffers {
|
for i := range buffers {
|
||||||
if !stun.IsMessage(buffers[i]) {
|
if !stun.IsMessage(buffers[i]) {
|
||||||
@@ -260,9 +284,14 @@ func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr)
|
|||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
|
|
||||||
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
|
s.muUDPMux.Lock()
|
||||||
if muxErr != nil {
|
mux := s.udpMux
|
||||||
log.Warnf("failed to handle STUN packet")
|
s.muUDPMux.Unlock()
|
||||||
|
|
||||||
|
if mux != nil {
|
||||||
|
if muxErr := mux.HandleSTUNMessage(msg, addr); muxErr != nil {
|
||||||
|
log.Warnf("failed to handle STUN packet: %v", muxErr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
buffers[i] = []byte{}
|
buffers[i] = []byte{}
|
||||||
|
|||||||
324
client/iface/bind/ice_bind_test.go
Normal file
324
client/iface/bind/ice_bind_test.go
Normal file
@@ -0,0 +1,324 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/transport/v3/stdnet"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestICEBind_CreatesReceiverForBothIPv4AndIPv6(t *testing.T) {
|
||||||
|
iceBind := setupICEBind(t)
|
||||||
|
|
||||||
|
ipv4Conn, ipv6Conn := createDualStackConns(t)
|
||||||
|
defer ipv4Conn.Close()
|
||||||
|
defer ipv6Conn.Close()
|
||||||
|
|
||||||
|
rc := receiverCreator{iceBind}
|
||||||
|
pool := createMsgPool()
|
||||||
|
|
||||||
|
// Simulate wireguard-go calling CreateReceiverFn for IPv4
|
||||||
|
ipv4RecvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, pool)
|
||||||
|
require.NotNil(t, ipv4RecvFn)
|
||||||
|
|
||||||
|
iceBind.muUDPMux.Lock()
|
||||||
|
assert.NotNil(t, iceBind.ipv4Conn, "should store IPv4 connection")
|
||||||
|
assert.Nil(t, iceBind.ipv6Conn, "IPv6 not added yet")
|
||||||
|
assert.NotNil(t, iceBind.udpMux, "mux should be created after first connection")
|
||||||
|
iceBind.muUDPMux.Unlock()
|
||||||
|
|
||||||
|
// Simulate wireguard-go calling CreateReceiverFn for IPv6
|
||||||
|
ipv6RecvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, pool)
|
||||||
|
require.NotNil(t, ipv6RecvFn)
|
||||||
|
|
||||||
|
iceBind.muUDPMux.Lock()
|
||||||
|
assert.NotNil(t, iceBind.ipv4Conn, "should still have IPv4 connection")
|
||||||
|
assert.NotNil(t, iceBind.ipv6Conn, "should now have IPv6 connection")
|
||||||
|
assert.NotNil(t, iceBind.udpMux, "mux should still exist")
|
||||||
|
iceBind.muUDPMux.Unlock()
|
||||||
|
|
||||||
|
mux, err := iceBind.GetICEMux()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, mux)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestICEBind_WorksWithIPv4Only(t *testing.T) {
|
||||||
|
iceBind := setupICEBind(t)
|
||||||
|
|
||||||
|
ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer ipv4Conn.Close()
|
||||||
|
|
||||||
|
rc := receiverCreator{iceBind}
|
||||||
|
recvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, createMsgPool())
|
||||||
|
require.NotNil(t, recvFn)
|
||||||
|
|
||||||
|
iceBind.muUDPMux.Lock()
|
||||||
|
assert.NotNil(t, iceBind.ipv4Conn)
|
||||||
|
assert.Nil(t, iceBind.ipv6Conn)
|
||||||
|
assert.NotNil(t, iceBind.udpMux)
|
||||||
|
iceBind.muUDPMux.Unlock()
|
||||||
|
|
||||||
|
mux, err := iceBind.GetICEMux()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, mux)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestICEBind_WorksWithIPv6Only(t *testing.T) {
|
||||||
|
iceBind := setupICEBind(t)
|
||||||
|
|
||||||
|
ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("IPv6 not available: %v", err)
|
||||||
|
}
|
||||||
|
defer ipv6Conn.Close()
|
||||||
|
|
||||||
|
rc := receiverCreator{iceBind}
|
||||||
|
recvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, createMsgPool())
|
||||||
|
require.NotNil(t, recvFn)
|
||||||
|
|
||||||
|
iceBind.muUDPMux.Lock()
|
||||||
|
assert.Nil(t, iceBind.ipv4Conn)
|
||||||
|
assert.NotNil(t, iceBind.ipv6Conn)
|
||||||
|
assert.NotNil(t, iceBind.udpMux)
|
||||||
|
iceBind.muUDPMux.Unlock()
|
||||||
|
|
||||||
|
mux, err := iceBind.GetICEMux()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, mux)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously verifies that we can communicate
|
||||||
|
// with peers on different address families through the same DualStackPacketConn.
|
||||||
|
func TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously(t *testing.T) {
|
||||||
|
// two "remote peers" listening on different address families
|
||||||
|
ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0")
|
||||||
|
defer ipv4Peer.Close()
|
||||||
|
|
||||||
|
ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("IPv6 not available: %v", err)
|
||||||
|
}
|
||||||
|
defer ipv6Peer.Close()
|
||||||
|
|
||||||
|
// our local dual-stack connection
|
||||||
|
ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0")
|
||||||
|
defer ipv4Local.Close()
|
||||||
|
|
||||||
|
ipv6Local := listenUDP(t, "udp6", "[::1]:0")
|
||||||
|
defer ipv6Local.Close()
|
||||||
|
|
||||||
|
dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local)
|
||||||
|
|
||||||
|
// send to both peers
|
||||||
|
_, err = dualStack.WriteTo([]byte("to-ipv4"), ipv4Peer.LocalAddr())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = dualStack.WriteTo([]byte("to-ipv6"), ipv6Peer.LocalAddr())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// verify IPv4 peer got its packet from the IPv4 socket
|
||||||
|
buf := make([]byte, 100)
|
||||||
|
_ = ipv4Peer.SetReadDeadline(time.Now().Add(time.Second))
|
||||||
|
n, addr, err := ipv4Peer.ReadFrom(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "to-ipv4", string(buf[:n]))
|
||||||
|
assert.Equal(t, ipv4Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port)
|
||||||
|
|
||||||
|
// verify IPv6 peer got its packet from the IPv6 socket
|
||||||
|
_ = ipv6Peer.SetReadDeadline(time.Now().Add(time.Second))
|
||||||
|
n, addr, err = ipv6Peer.ReadFrom(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "to-ipv6", string(buf[:n]))
|
||||||
|
assert.Equal(t, ipv6Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestICEBind_HandlesConcurrentMixedTraffic sends packets concurrently to both IPv4
|
||||||
|
// and IPv6 peers. Verifies no packets get misrouted (IPv4 peer only gets v4- packets,
|
||||||
|
// IPv6 peer only gets v6- packets). Some packet loss is acceptable for UDP.
|
||||||
|
func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
|
||||||
|
ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0")
|
||||||
|
defer ipv4Peer.Close()
|
||||||
|
|
||||||
|
ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("IPv6 not available: %v", err)
|
||||||
|
}
|
||||||
|
defer ipv6Peer.Close()
|
||||||
|
|
||||||
|
ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0")
|
||||||
|
defer ipv4Local.Close()
|
||||||
|
|
||||||
|
ipv6Local := listenUDP(t, "udp6", "[::1]:0")
|
||||||
|
defer ipv6Local.Close()
|
||||||
|
|
||||||
|
dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local)
|
||||||
|
|
||||||
|
const packetsPerFamily = 500
|
||||||
|
|
||||||
|
ipv4Received := make(chan string, packetsPerFamily)
|
||||||
|
ipv6Received := make(chan string, packetsPerFamily)
|
||||||
|
|
||||||
|
startGate := make(chan struct{})
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
buf := make([]byte, 100)
|
||||||
|
for i := 0; i < packetsPerFamily; i++ {
|
||||||
|
n, _, err := ipv4Peer.ReadFrom(buf)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ipv4Received <- string(buf[:n])
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
buf := make([]byte, 100)
|
||||||
|
for i := 0; i < packetsPerFamily; i++ {
|
||||||
|
n, _, err := ipv6Peer.ReadFrom(buf)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ipv6Received <- string(buf[:n])
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
<-startGate
|
||||||
|
for i := 0; i < packetsPerFamily; i++ {
|
||||||
|
_, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v4-%04d", i)), ipv4Peer.LocalAddr())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
<-startGate
|
||||||
|
for i := 0; i < packetsPerFamily; i++ {
|
||||||
|
_, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v6-%04d", i)), ipv6Peer.LocalAddr())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
close(startGate)
|
||||||
|
|
||||||
|
time.AfterFunc(5*time.Second, func() {
|
||||||
|
_ = ipv4Peer.SetReadDeadline(time.Now())
|
||||||
|
_ = ipv6Peer.SetReadDeadline(time.Now())
|
||||||
|
})
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
close(ipv4Received)
|
||||||
|
close(ipv6Received)
|
||||||
|
|
||||||
|
ipv4Count := 0
|
||||||
|
for pkt := range ipv4Received {
|
||||||
|
require.True(t, len(pkt) >= 3 && pkt[:3] == "v4-", "IPv4 peer got misrouted packet: %s", pkt)
|
||||||
|
ipv4Count++
|
||||||
|
}
|
||||||
|
|
||||||
|
ipv6Count := 0
|
||||||
|
for pkt := range ipv6Received {
|
||||||
|
require.True(t, len(pkt) >= 3 && pkt[:3] == "v6-", "IPv6 peer got misrouted packet: %s", pkt)
|
||||||
|
ipv6Count++
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, packetsPerFamily, ipv4Count)
|
||||||
|
assert.Equal(t, packetsPerFamily, ipv6Count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
network string
|
||||||
|
addr string
|
||||||
|
wantIPv4 bool
|
||||||
|
}{
|
||||||
|
{"IPv4 any", "udp4", "0.0.0.0:0", true},
|
||||||
|
{"IPv4 loopback", "udp4", "127.0.0.1:0", true},
|
||||||
|
{"IPv6 any", "udp6", "[::]:0", false},
|
||||||
|
{"IPv6 loopback", "udp6", "[::1]:0", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
addr, err := net.ResolveUDPAddr(tt.network, tt.addr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
conn, err := net.ListenUDP(tt.network, addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("%s not available: %v", tt.network, err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
localAddr := conn.LocalAddr().(*net.UDPAddr)
|
||||||
|
isIPv4 := localAddr.IP.To4() != nil
|
||||||
|
assert.Equal(t, tt.wantIPv4, isIPv4)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// helpers
|
||||||
|
|
||||||
|
func setupICEBind(t *testing.T) *ICEBind {
|
||||||
|
t.Helper()
|
||||||
|
transportNet, err := stdnet.NewNet()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
address := wgaddr.Address{
|
||||||
|
IP: netip.MustParseAddr("100.64.0.1"),
|
||||||
|
Network: netip.MustParsePrefix("100.64.0.0/10"),
|
||||||
|
}
|
||||||
|
return NewICEBind(transportNet, nil, address, 1280)
|
||||||
|
}
|
||||||
|
|
||||||
|
func createDualStackConns(t *testing.T) (*net.UDPConn, *net.UDPConn) {
|
||||||
|
t.Helper()
|
||||||
|
ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
ipv4Conn.Close()
|
||||||
|
t.Skipf("IPv6 not available: %v", err)
|
||||||
|
}
|
||||||
|
return ipv4Conn, ipv6Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func createMsgPool() *sync.Pool {
|
||||||
|
return &sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
msgs := make([]ipv6.Message, 1)
|
||||||
|
for i := range msgs {
|
||||||
|
msgs[i].Buffers = make(net.Buffers, 1)
|
||||||
|
msgs[i].OOB = make([]byte, 0, 40)
|
||||||
|
}
|
||||||
|
return &msgs
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func listenUDP(t *testing.T, network, addr string) *net.UDPConn {
|
||||||
|
t.Helper()
|
||||||
|
udpAddr, err := net.ResolveUDPAddr(network, addr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
conn, err := net.ListenUDP(network, udpAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return conn
|
||||||
|
}
|
||||||
@@ -3,8 +3,22 @@ package configurer
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// buildPresharedKeyConfig creates a wgtypes.Config for setting a preshared key on a peer.
|
||||||
|
// This is a shared helper used by both kernel and userspace configurers.
|
||||||
|
func buildPresharedKeyConfig(peerKey wgtypes.Key, psk wgtypes.Key, updateOnly bool) wgtypes.Config {
|
||||||
|
return wgtypes.Config{
|
||||||
|
Peers: []wgtypes.PeerConfig{{
|
||||||
|
PublicKey: peerKey,
|
||||||
|
PresharedKey: &psk,
|
||||||
|
UpdateOnly: updateOnly,
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
|
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
|
||||||
ipNets := make([]net.IPNet, len(prefixes))
|
ipNets := make([]net.IPNet, len(prefixes))
|
||||||
for i, prefix := range prefixes {
|
for i, prefix := range prefixes {
|
||||||
|
|||||||
@@ -15,8 +15,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/monotime"
|
"github.com/netbirdio/netbird/monotime"
|
||||||
)
|
)
|
||||||
|
|
||||||
var zeroKey wgtypes.Key
|
|
||||||
|
|
||||||
type KernelConfigurer struct {
|
type KernelConfigurer struct {
|
||||||
deviceName string
|
deviceName string
|
||||||
}
|
}
|
||||||
@@ -48,6 +46,18 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetPresharedKey sets the preshared key for a peer.
|
||||||
|
// If updateOnly is true, only updates the existing peer; if false, creates or updates.
|
||||||
|
func (c *KernelConfigurer) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
|
||||||
|
parsedPeerKey, err := wgtypes.ParseKey(peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := buildPresharedKeyConfig(parsedPeerKey, psk, updateOnly)
|
||||||
|
return c.configure(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -279,7 +289,7 @@ func (c *KernelConfigurer) FullStats() (*Stats, error) {
|
|||||||
TxBytes: p.TransmitBytes,
|
TxBytes: p.TransmitBytes,
|
||||||
RxBytes: p.ReceiveBytes,
|
RxBytes: p.ReceiveBytes,
|
||||||
LastHandshake: p.LastHandshakeTime,
|
LastHandshake: p.LastHandshakeTime,
|
||||||
PresharedKey: p.PresharedKey != zeroKey,
|
PresharedKey: [32]byte(p.PresharedKey),
|
||||||
}
|
}
|
||||||
if p.Endpoint != nil {
|
if p.Endpoint != nil {
|
||||||
peer.Endpoint = *p.Endpoint
|
peer.Endpoint = *p.Endpoint
|
||||||
|
|||||||
@@ -22,17 +22,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
privateKey = "private_key"
|
privateKey = "private_key"
|
||||||
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
|
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
|
||||||
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
|
ipcKeyTxBytes = "tx_bytes"
|
||||||
ipcKeyTxBytes = "tx_bytes"
|
ipcKeyRxBytes = "rx_bytes"
|
||||||
ipcKeyRxBytes = "rx_bytes"
|
allowedIP = "allowed_ip"
|
||||||
allowedIP = "allowed_ip"
|
endpoint = "endpoint"
|
||||||
endpoint = "endpoint"
|
fwmark = "fwmark"
|
||||||
fwmark = "fwmark"
|
listenPort = "listen_port"
|
||||||
listenPort = "listen_port"
|
publicKey = "public_key"
|
||||||
publicKey = "public_key"
|
presharedKey = "preshared_key"
|
||||||
presharedKey = "preshared_key"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
||||||
@@ -72,6 +71,18 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
|
|||||||
return c.device.IpcSet(toWgUserspaceString(config))
|
return c.device.IpcSet(toWgUserspaceString(config))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetPresharedKey sets the preshared key for a peer.
|
||||||
|
// If updateOnly is true, only updates the existing peer; if false, creates or updates.
|
||||||
|
func (c *WGUSPConfigurer) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
|
||||||
|
parsedPeerKey, err := wgtypes.ParseKey(peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := buildPresharedKeyConfig(parsedPeerKey, psk, updateOnly)
|
||||||
|
return c.device.IpcSet(toWgUserspaceString(cfg))
|
||||||
|
}
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -422,23 +433,19 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
|||||||
hexKey := hex.EncodeToString(p.PublicKey[:])
|
hexKey := hex.EncodeToString(p.PublicKey[:])
|
||||||
sb.WriteString(fmt.Sprintf("public_key=%s\n", hexKey))
|
sb.WriteString(fmt.Sprintf("public_key=%s\n", hexKey))
|
||||||
|
|
||||||
|
if p.Remove {
|
||||||
|
sb.WriteString("remove=true\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.UpdateOnly {
|
||||||
|
sb.WriteString("update_only=true\n")
|
||||||
|
}
|
||||||
|
|
||||||
if p.PresharedKey != nil {
|
if p.PresharedKey != nil {
|
||||||
preSharedHexKey := hex.EncodeToString(p.PresharedKey[:])
|
preSharedHexKey := hex.EncodeToString(p.PresharedKey[:])
|
||||||
sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey))
|
sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey))
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.Remove {
|
|
||||||
sb.WriteString("remove=true")
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.ReplaceAllowedIPs {
|
|
||||||
sb.WriteString("replace_allowed_ips=true\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, aip := range p.AllowedIPs {
|
|
||||||
sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String()))
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.Endpoint != nil {
|
if p.Endpoint != nil {
|
||||||
sb.WriteString(fmt.Sprintf("endpoint=%s\n", p.Endpoint.String()))
|
sb.WriteString(fmt.Sprintf("endpoint=%s\n", p.Endpoint.String()))
|
||||||
}
|
}
|
||||||
@@ -446,6 +453,14 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
|||||||
if p.PersistentKeepaliveInterval != nil {
|
if p.PersistentKeepaliveInterval != nil {
|
||||||
sb.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds())))
|
sb.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds())))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if p.ReplaceAllowedIPs {
|
||||||
|
sb.WriteString("replace_allowed_ips=true\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, aip := range p.AllowedIPs {
|
||||||
|
sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String()))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return sb.String()
|
return sb.String()
|
||||||
}
|
}
|
||||||
@@ -543,7 +558,7 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
|
host, portStr, err := net.SplitHostPort(val)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to parse endpoint: %v", err)
|
log.Errorf("failed to parse endpoint: %v", err)
|
||||||
continue
|
continue
|
||||||
@@ -599,7 +614,9 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if val != "" && val != "0000000000000000000000000000000000000000000000000000000000000000" {
|
if val != "" && val != "0000000000000000000000000000000000000000000000000000000000000000" {
|
||||||
currentPeer.PresharedKey = true
|
if pskKey, err := hexToWireguardKey(val); err == nil {
|
||||||
|
currentPeer.PresharedKey = [32]byte(pskKey)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ type Peer struct {
|
|||||||
TxBytes int64
|
TxBytes int64
|
||||||
RxBytes int64
|
RxBytes int64
|
||||||
LastHandshake time.Time
|
LastHandshake time.Time
|
||||||
PresharedKey bool
|
PresharedKey [32]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
type Stats struct {
|
type Stats struct {
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ type WGConfigurer interface {
|
|||||||
RemovePeer(peerKey string) error
|
RemovePeer(peerKey string) error
|
||||||
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
||||||
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
||||||
|
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
|
||||||
Close()
|
Close()
|
||||||
GetStats() (map[string]configurer.WGStats, error)
|
GetStats() (map[string]configurer.WGStats, error)
|
||||||
FullStats() (*configurer.Stats, error)
|
FullStats() (*configurer.Stats, error)
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ func ValidateMTU(mtu uint16) error {
|
|||||||
|
|
||||||
type wgProxyFactory interface {
|
type wgProxyFactory interface {
|
||||||
GetProxy() wgproxy.Proxy
|
GetProxy() wgproxy.Proxy
|
||||||
|
GetProxyPort() uint16
|
||||||
Free() error
|
Free() error
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,6 +81,12 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
|
|||||||
return w.wgProxyFactory.GetProxy()
|
return w.wgProxyFactory.GetProxy()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetProxyPort returns the proxy port used by the WireGuard proxy.
|
||||||
|
// Returns 0 if no proxy port is used (e.g., for userspace WireGuard).
|
||||||
|
func (w *WGIface) GetProxyPort() uint16 {
|
||||||
|
return w.wgProxyFactory.GetProxyPort()
|
||||||
|
}
|
||||||
|
|
||||||
// GetBind returns the EndpointManager userspace bind mode.
|
// GetBind returns the EndpointManager userspace bind mode.
|
||||||
func (w *WGIface) GetBind() device.EndpointManager {
|
func (w *WGIface) GetBind() device.EndpointManager {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
@@ -297,6 +304,19 @@ func (w *WGIface) FullStats() (*configurer.Stats, error) {
|
|||||||
return w.configurer.FullStats()
|
return w.configurer.FullStats()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetPresharedKey sets or updates the preshared key for a peer.
|
||||||
|
// If updateOnly is true, only updates existing peer; if false, creates or updates.
|
||||||
|
func (w *WGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
|
if w.configurer == nil {
|
||||||
|
return ErrIfaceNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.configurer.SetPresharedKey(peerKey, psk, updateOnly)
|
||||||
|
}
|
||||||
|
|
||||||
func (w *WGIface) waitUntilRemoved() error {
|
func (w *WGIface) waitUntilRemoved() error {
|
||||||
maxWaitTime := 5 * time.Second
|
maxWaitTime := 5 * time.Second
|
||||||
timeout := time.NewTimer(maxWaitTime)
|
timeout := time.NewTimer(maxWaitTime)
|
||||||
|
|||||||
@@ -114,21 +114,21 @@ func (p *ProxyBind) Pause() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
|
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
|
||||||
|
ep, err := addrToEndpoint(endpoint)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to start package redirection: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
p.pausedCond.L.Lock()
|
p.pausedCond.L.Lock()
|
||||||
p.paused = false
|
p.paused = false
|
||||||
|
|
||||||
p.wgCurrentUsed = addrToEndpoint(endpoint)
|
p.wgCurrentUsed = ep
|
||||||
|
|
||||||
p.pausedCond.Signal()
|
p.pausedCond.Signal()
|
||||||
p.pausedCond.L.Unlock()
|
p.pausedCond.L.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
|
|
||||||
ip, _ := netip.AddrFromSlice(addr.IP.To4())
|
|
||||||
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
|
|
||||||
return &bind.Endpoint{AddrPort: addrPort}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ProxyBind) CloseConn() error {
|
func (p *ProxyBind) CloseConn() error {
|
||||||
if p.cancel == nil {
|
if p.cancel == nil {
|
||||||
return fmt.Errorf("proxy not started")
|
return fmt.Errorf("proxy not started")
|
||||||
@@ -212,3 +212,16 @@ func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
|
|||||||
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
|
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
|
||||||
return &netipAddr, nil
|
return &netipAddr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
|
||||||
|
if addr == nil {
|
||||||
|
return nil, fmt.Errorf("invalid address")
|
||||||
|
}
|
||||||
|
ip, ok := netip.AddrFromSlice(addr.IP)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("convert %s to netip.Addr", addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
|
||||||
|
return &bind.Endpoint{AddrPort: addrPort}, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -26,13 +24,10 @@ const (
|
|||||||
loopbackAddr = "127.0.0.1"
|
loopbackAddr = "127.0.0.1"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
localHostNetIP = net.ParseIP("127.0.0.1")
|
|
||||||
)
|
|
||||||
|
|
||||||
// WGEBPFProxy definition for proxy with EBPF support
|
// WGEBPFProxy definition for proxy with EBPF support
|
||||||
type WGEBPFProxy struct {
|
type WGEBPFProxy struct {
|
||||||
localWGListenPort int
|
localWGListenPort int
|
||||||
|
proxyPort int
|
||||||
mtu uint16
|
mtu uint16
|
||||||
|
|
||||||
ebpfManager ebpfMgr.Manager
|
ebpfManager ebpfMgr.Manager
|
||||||
@@ -40,7 +35,8 @@ type WGEBPFProxy struct {
|
|||||||
turnConnMutex sync.Mutex
|
turnConnMutex sync.Mutex
|
||||||
|
|
||||||
lastUsedPort uint16
|
lastUsedPort uint16
|
||||||
rawConn net.PacketConn
|
rawConnIPv4 net.PacketConn
|
||||||
|
rawConnIPv6 net.PacketConn
|
||||||
conn transport.UDPConn
|
conn transport.UDPConn
|
||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
@@ -62,23 +58,39 @@ func NewWGEBPFProxy(wgPort int, mtu uint16) *WGEBPFProxy {
|
|||||||
// Listen load ebpf program and listen the proxy
|
// Listen load ebpf program and listen the proxy
|
||||||
func (p *WGEBPFProxy) Listen() error {
|
func (p *WGEBPFProxy) Listen() error {
|
||||||
pl := portLookup{}
|
pl := portLookup{}
|
||||||
wgPorxyPort, err := pl.searchFreePort()
|
proxyPort, err := pl.searchFreePort()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
p.proxyPort = proxyPort
|
||||||
|
|
||||||
|
// Prepare IPv4 raw socket (required)
|
||||||
|
p.rawConnIPv4, err = rawsocket.PrepareSenderRawSocketIPv4()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
p.rawConn, err = rawsocket.PrepareSenderRawSocket()
|
// Prepare IPv6 raw socket (optional)
|
||||||
|
p.rawConnIPv6, err = rawsocket.PrepareSenderRawSocketIPv6()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
log.Warnf("failed to prepare IPv6 raw socket, continuing with IPv4 only: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort)
|
err = p.ebpfManager.LoadWgProxy(proxyPort, p.localWGListenPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if closeErr := p.rawConnIPv4.Close(); closeErr != nil {
|
||||||
|
log.Warnf("failed to close IPv4 raw socket: %v", closeErr)
|
||||||
|
}
|
||||||
|
if p.rawConnIPv6 != nil {
|
||||||
|
if closeErr := p.rawConnIPv6.Close(); closeErr != nil {
|
||||||
|
log.Warnf("failed to close IPv6 raw socket: %v", closeErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
addr := net.UDPAddr{
|
addr := net.UDPAddr{
|
||||||
Port: wgPorxyPort,
|
Port: proxyPort,
|
||||||
IP: net.ParseIP(loopbackAddr),
|
IP: net.ParseIP(loopbackAddr),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,7 +106,7 @@ func (p *WGEBPFProxy) Listen() error {
|
|||||||
p.conn = conn
|
p.conn = conn
|
||||||
|
|
||||||
go p.proxyToRemote()
|
go p.proxyToRemote()
|
||||||
log.Infof("local wg proxy listening on: %d", wgPorxyPort)
|
log.Infof("local wg proxy listening on: %d", proxyPort)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -135,12 +147,25 @@ func (p *WGEBPFProxy) Free() error {
|
|||||||
result = multierror.Append(result, err)
|
result = multierror.Append(result, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.rawConn.Close(); err != nil {
|
if p.rawConnIPv4 != nil {
|
||||||
result = multierror.Append(result, err)
|
if err := p.rawConnIPv4.Close(); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.rawConnIPv6 != nil {
|
||||||
|
if err := p.rawConnIPv6.Close(); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nberrors.FormatErrorOrNil(result)
|
return nberrors.FormatErrorOrNil(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetProxyPort returns the proxy listening port.
|
||||||
|
func (p *WGEBPFProxy) GetProxyPort() uint16 {
|
||||||
|
return uint16(p.proxyPort)
|
||||||
|
}
|
||||||
|
|
||||||
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
||||||
// From this go routine has only one instance.
|
// From this go routine has only one instance.
|
||||||
func (p *WGEBPFProxy) proxyToRemote() {
|
func (p *WGEBPFProxy) proxyToRemote() {
|
||||||
@@ -216,34 +241,3 @@ generatePort:
|
|||||||
}
|
}
|
||||||
return p.lastUsedPort, nil
|
return p.lastUsedPort, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
|
|
||||||
payload := gopacket.Payload(data)
|
|
||||||
ipH := &layers.IPv4{
|
|
||||||
DstIP: localHostNetIP,
|
|
||||||
SrcIP: endpointAddr.IP,
|
|
||||||
Version: 4,
|
|
||||||
TTL: 64,
|
|
||||||
Protocol: layers.IPProtocolUDP,
|
|
||||||
}
|
|
||||||
udpH := &layers.UDP{
|
|
||||||
SrcPort: layers.UDPPort(endpointAddr.Port),
|
|
||||||
DstPort: layers.UDPPort(p.localWGListenPort),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("set network layer for checksum: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
layerBuffer := gopacket.NewSerializeBuffer()
|
|
||||||
|
|
||||||
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("serialize layers: %w", err)
|
|
||||||
}
|
|
||||||
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
|
|
||||||
return fmt.Errorf("write to raw conn: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -10,12 +10,89 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bufsize"
|
"github.com/netbirdio/netbird/client/iface/bufsize"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errIPv6ConnNotAvailable = errors.New("IPv6 endpoint but rawConnIPv6 is not available")
|
||||||
|
errIPv4ConnNotAvailable = errors.New("IPv4 endpoint but rawConnIPv4 is not available")
|
||||||
|
|
||||||
|
localHostNetIPv4 = net.ParseIP("127.0.0.1")
|
||||||
|
localHostNetIPv6 = net.ParseIP("::1")
|
||||||
|
|
||||||
|
serializeOpts = gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: true,
|
||||||
|
FixLengths: true,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// PacketHeaders holds pre-created headers and buffers for efficient packet sending
|
||||||
|
type PacketHeaders struct {
|
||||||
|
ipH gopacket.SerializableLayer
|
||||||
|
udpH *layers.UDP
|
||||||
|
layerBuffer gopacket.SerializeBuffer
|
||||||
|
localHostAddr net.IP
|
||||||
|
isIPv4 bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPacketHeaders(localWGListenPort int, endpoint *net.UDPAddr) (*PacketHeaders, error) {
|
||||||
|
var ipH gopacket.SerializableLayer
|
||||||
|
var networkLayer gopacket.NetworkLayer
|
||||||
|
var localHostAddr net.IP
|
||||||
|
var isIPv4 bool
|
||||||
|
|
||||||
|
// Check if source address is IPv4 or IPv6
|
||||||
|
if endpoint.IP.To4() != nil {
|
||||||
|
// IPv4 path
|
||||||
|
ipv4 := &layers.IPv4{
|
||||||
|
DstIP: localHostNetIPv4,
|
||||||
|
SrcIP: endpoint.IP,
|
||||||
|
Version: 4,
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: layers.IPProtocolUDP,
|
||||||
|
}
|
||||||
|
ipH = ipv4
|
||||||
|
networkLayer = ipv4
|
||||||
|
localHostAddr = localHostNetIPv4
|
||||||
|
isIPv4 = true
|
||||||
|
} else {
|
||||||
|
// IPv6 path
|
||||||
|
ipv6 := &layers.IPv6{
|
||||||
|
DstIP: localHostNetIPv6,
|
||||||
|
SrcIP: endpoint.IP,
|
||||||
|
Version: 6,
|
||||||
|
HopLimit: 64,
|
||||||
|
NextHeader: layers.IPProtocolUDP,
|
||||||
|
}
|
||||||
|
ipH = ipv6
|
||||||
|
networkLayer = ipv6
|
||||||
|
localHostAddr = localHostNetIPv6
|
||||||
|
isIPv4 = false
|
||||||
|
}
|
||||||
|
|
||||||
|
udpH := &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(endpoint.Port),
|
||||||
|
DstPort: layers.UDPPort(localWGListenPort),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
|
||||||
|
return nil, fmt.Errorf("set network layer for checksum: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &PacketHeaders{
|
||||||
|
ipH: ipH,
|
||||||
|
udpH: udpH,
|
||||||
|
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||||
|
localHostAddr: localHostAddr,
|
||||||
|
isIPv4: isIPv4,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||||
type ProxyWrapper struct {
|
type ProxyWrapper struct {
|
||||||
wgeBPFProxy *WGEBPFProxy
|
wgeBPFProxy *WGEBPFProxy
|
||||||
@@ -24,8 +101,10 @@ type ProxyWrapper struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
|
|
||||||
wgRelayedEndpointAddr *net.UDPAddr
|
wgRelayedEndpointAddr *net.UDPAddr
|
||||||
wgEndpointCurrentUsedAddr *net.UDPAddr
|
headers *PacketHeaders
|
||||||
|
headerCurrentUsed *PacketHeaders
|
||||||
|
rawConn net.PacketConn
|
||||||
|
|
||||||
paused bool
|
paused bool
|
||||||
pausedCond *sync.Cond
|
pausedCond *sync.Cond
|
||||||
@@ -41,15 +120,32 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
|
|||||||
closeListener: listener.NewCloseListener(),
|
closeListener: listener.NewCloseListener(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
|
||||||
|
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
|
||||||
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
|
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("add turn conn: %w", err)
|
return fmt.Errorf("add turn conn: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
headers, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, addr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create packet sender: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if required raw connection is available
|
||||||
|
if !headers.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
|
||||||
|
return errIPv6ConnNotAvailable
|
||||||
|
}
|
||||||
|
if headers.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
|
||||||
|
return errIPv4ConnNotAvailable
|
||||||
|
}
|
||||||
|
|
||||||
p.remoteConn = remoteConn
|
p.remoteConn = remoteConn
|
||||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||||
p.wgRelayedEndpointAddr = addr
|
p.wgRelayedEndpointAddr = addr
|
||||||
return err
|
p.headers = headers
|
||||||
|
p.rawConn = p.selectRawConn(headers)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
||||||
@@ -68,7 +164,8 @@ func (p *ProxyWrapper) Work() {
|
|||||||
p.pausedCond.L.Lock()
|
p.pausedCond.L.Lock()
|
||||||
p.paused = false
|
p.paused = false
|
||||||
|
|
||||||
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
|
p.headerCurrentUsed = p.headers
|
||||||
|
p.rawConn = p.selectRawConn(p.headerCurrentUsed)
|
||||||
|
|
||||||
if !p.isStarted {
|
if !p.isStarted {
|
||||||
p.isStarted = true
|
p.isStarted = true
|
||||||
@@ -91,10 +188,32 @@ func (p *ProxyWrapper) Pause() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
|
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
|
||||||
|
if endpoint == nil || endpoint.IP == nil {
|
||||||
|
log.Errorf("failed to start package redirection, endpoint is nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
header, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, endpoint)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to create packet headers: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if required raw connection is available
|
||||||
|
if !header.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
|
||||||
|
log.Error(errIPv6ConnNotAvailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if header.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
|
||||||
|
log.Error(errIPv4ConnNotAvailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
p.pausedCond.L.Lock()
|
p.pausedCond.L.Lock()
|
||||||
p.paused = false
|
p.paused = false
|
||||||
|
|
||||||
p.wgEndpointCurrentUsedAddr = endpoint
|
p.headerCurrentUsed = header
|
||||||
|
p.rawConn = p.selectRawConn(header)
|
||||||
|
|
||||||
p.pausedCond.Signal()
|
p.pausedCond.Signal()
|
||||||
p.pausedCond.L.Unlock()
|
p.pausedCond.L.Unlock()
|
||||||
@@ -136,7 +255,7 @@ func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
|
|||||||
p.pausedCond.Wait()
|
p.pausedCond.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
|
err = p.sendPkg(buf[:n], p.headerCurrentUsed)
|
||||||
p.pausedCond.L.Unlock()
|
p.pausedCond.L.Unlock()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -162,3 +281,29 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
|
|||||||
}
|
}
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) sendPkg(data []byte, header *PacketHeaders) error {
|
||||||
|
defer func() {
|
||||||
|
if err := header.layerBuffer.Clear(); err != nil {
|
||||||
|
log.Errorf("failed to clear layer buffer: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
payload := gopacket.Payload(data)
|
||||||
|
|
||||||
|
if err := gopacket.SerializeLayers(header.layerBuffer, serializeOpts, header.ipH, header.udpH, payload); err != nil {
|
||||||
|
return fmt.Errorf("serialize layers: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := p.rawConn.WriteTo(header.layerBuffer.Bytes(), &net.IPAddr{IP: header.localHostAddr}); err != nil {
|
||||||
|
return fmt.Errorf("write to raw conn: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) selectRawConn(header *PacketHeaders) net.PacketConn {
|
||||||
|
if header.isIPv4 {
|
||||||
|
return p.wgeBPFProxy.rawConnIPv4
|
||||||
|
}
|
||||||
|
return p.wgeBPFProxy.rawConnIPv6
|
||||||
|
}
|
||||||
|
|||||||
@@ -54,6 +54,14 @@ func (w *KernelFactory) GetProxy() Proxy {
|
|||||||
return ebpf.NewProxyWrapper(w.ebpfProxy)
|
return ebpf.NewProxyWrapper(w.ebpfProxy)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetProxyPort returns the eBPF proxy port, or 0 if eBPF is not active.
|
||||||
|
func (w *KernelFactory) GetProxyPort() uint16 {
|
||||||
|
if w.ebpfProxy == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return w.ebpfProxy.GetProxyPort()
|
||||||
|
}
|
||||||
|
|
||||||
func (w *KernelFactory) Free() error {
|
func (w *KernelFactory) Free() error {
|
||||||
if w.ebpfProxy == nil {
|
if w.ebpfProxy == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -24,6 +24,11 @@ func (w *USPFactory) GetProxy() Proxy {
|
|||||||
return proxyBind.NewProxyBind(w.bind, w.mtu)
|
return proxyBind.NewProxyBind(w.bind, w.mtu)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetProxyPort returns 0 as userspace WireGuard doesn't use a separate proxy port.
|
||||||
|
func (w *USPFactory) GetProxyPort() uint16 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
func (w *USPFactory) Free() error {
|
func (w *USPFactory) Free() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,43 +8,87 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
func PrepareSenderRawSocket() (net.PacketConn, error) {
|
// PrepareSenderRawSocketIPv4 creates and configures a raw socket for sending IPv4 packets
|
||||||
|
func PrepareSenderRawSocketIPv4() (net.PacketConn, error) {
|
||||||
|
return prepareSenderRawSocket(syscall.AF_INET, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrepareSenderRawSocketIPv6 creates and configures a raw socket for sending IPv6 packets
|
||||||
|
func PrepareSenderRawSocketIPv6() (net.PacketConn, error) {
|
||||||
|
return prepareSenderRawSocket(syscall.AF_INET6, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepareSenderRawSocket(family int, isIPv4 bool) (net.PacketConn, error) {
|
||||||
// Create a raw socket.
|
// Create a raw socket.
|
||||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
fd, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
|
// Set the header include option on the socket to tell the kernel that headers are included in the packet.
|
||||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
// For IPv4, we need to set IP_HDRINCL. For IPv6, we need to set IPV6_HDRINCL to accept application-provided IPv6 headers.
|
||||||
if err != nil {
|
if isIPv4 {
|
||||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, unix.IP_HDRINCL, 1)
|
||||||
|
if err != nil {
|
||||||
|
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||||
|
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IPV6, unix.IPV6_HDRINCL, 1)
|
||||||
|
if err != nil {
|
||||||
|
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||||
|
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("setting IPV6_HDRINCL failed: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bind the socket to the "lo" interface.
|
// Bind the socket to the "lo" interface.
|
||||||
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||||
|
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the fwmark on the socket.
|
// Set the fwmark on the socket.
|
||||||
err = nbnet.SetSocketOpt(fd)
|
err = nbnet.SetSocketOpt(fd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||||
|
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert the file descriptor to a PacketConn.
|
// Convert the file descriptor to a PacketConn.
|
||||||
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
||||||
if file == nil {
|
if file == nil {
|
||||||
|
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||||
|
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("converting fd to file failed")
|
return nil, fmt.Errorf("converting fd to file failed")
|
||||||
}
|
}
|
||||||
packetConn, err := net.FilePacketConn(file)
|
packetConn, err := net.FilePacketConn(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if closeErr := file.Close(); closeErr != nil {
|
||||||
|
log.Warnf("failed to close file: %v", closeErr)
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close the original file to release the FD (net.FilePacketConn duplicates it)
|
||||||
|
if closeErr := file.Close(); closeErr != nil {
|
||||||
|
log.Warnf("failed to close file after creating packet conn: %v", closeErr)
|
||||||
|
}
|
||||||
|
|
||||||
return packetConn, nil
|
return packetConn, nil
|
||||||
}
|
}
|
||||||
|
|||||||
353
client/iface/wgproxy/redirect_test.go
Normal file
353
client/iface/wgproxy/redirect_test.go
Normal file
@@ -0,0 +1,353 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
package wgproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// compareUDPAddr compares two UDP addresses, ignoring IPv6 zone IDs
|
||||||
|
// IPv6 link-local addresses include zone IDs (e.g., fe80::1%lo) which we should ignore
|
||||||
|
func compareUDPAddr(addr1, addr2 net.Addr) bool {
|
||||||
|
udpAddr1, ok1 := addr1.(*net.UDPAddr)
|
||||||
|
udpAddr2, ok2 := addr2.(*net.UDPAddr)
|
||||||
|
|
||||||
|
if !ok1 || !ok2 {
|
||||||
|
return addr1.String() == addr2.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compare IP and Port, ignoring zone
|
||||||
|
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
|
||||||
|
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
|
||||||
|
wgPort := 51850
|
||||||
|
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||||
|
if err := ebpfProxy.Listen(); err != nil {
|
||||||
|
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := ebpfProxy.Free(); err != nil {
|
||||||
|
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||||
|
|
||||||
|
// NetBird UDP address of the remote peer
|
||||||
|
nbAddr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("100.108.111.177"),
|
||||||
|
Port: 38746,
|
||||||
|
}
|
||||||
|
|
||||||
|
p2pEndpoint := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("192.168.0.56"),
|
||||||
|
Port: 51820,
|
||||||
|
}
|
||||||
|
|
||||||
|
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
|
||||||
|
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
|
||||||
|
wgPort := 51851
|
||||||
|
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||||
|
if err := ebpfProxy.Listen(); err != nil {
|
||||||
|
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := ebpfProxy.Free(); err != nil {
|
||||||
|
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||||
|
|
||||||
|
// NetBird UDP address of the remote peer
|
||||||
|
nbAddr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("100.108.111.177"),
|
||||||
|
Port: 38746,
|
||||||
|
}
|
||||||
|
|
||||||
|
p2pEndpoint := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("fe80::56"),
|
||||||
|
Port: 51820,
|
||||||
|
}
|
||||||
|
|
||||||
|
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
|
||||||
|
func TestRedirectAs_UDP_IPv4(t *testing.T) {
|
||||||
|
wgPort := 51852
|
||||||
|
proxy := udp.NewWGUDPProxy(wgPort, 1280)
|
||||||
|
|
||||||
|
// NetBird UDP address of the remote peer
|
||||||
|
nbAddr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("100.108.111.177"),
|
||||||
|
Port: 38746,
|
||||||
|
}
|
||||||
|
|
||||||
|
p2pEndpoint := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("192.168.0.56"),
|
||||||
|
Port: 51820,
|
||||||
|
}
|
||||||
|
|
||||||
|
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectAs_UDP_IPv6 tests RedirectAs with UDP proxy using IPv6 addresses
|
||||||
|
func TestRedirectAs_UDP_IPv6(t *testing.T) {
|
||||||
|
wgPort := 51853
|
||||||
|
proxy := udp.NewWGUDPProxy(wgPort, 1280)
|
||||||
|
|
||||||
|
// NetBird UDP address of the remote peer
|
||||||
|
nbAddr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("100.108.111.177"),
|
||||||
|
Port: 38746,
|
||||||
|
}
|
||||||
|
|
||||||
|
p2pEndpoint := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("fe80::56"),
|
||||||
|
Port: 51820,
|
||||||
|
}
|
||||||
|
|
||||||
|
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// testRedirectAs is a helper function that tests the RedirectAs functionality
|
||||||
|
// It verifies that:
|
||||||
|
// 1. Initial traffic from relay connection works
|
||||||
|
// 2. After calling RedirectAs, packets appear to come from the p2p endpoint
|
||||||
|
// 3. Multiple packets are correctly redirected with the new source address
|
||||||
|
func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *net.UDPAddr) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create WireGuard listeners on both IPv4 and IPv6 to support both P2P connection types
|
||||||
|
// In reality, WireGuard binds to a port and receives from both IPv4 and IPv6
|
||||||
|
wgListener4, err := net.ListenUDP("udp4", &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
Port: wgPort,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create IPv4 WireGuard listener: %v", err)
|
||||||
|
}
|
||||||
|
defer wgListener4.Close()
|
||||||
|
|
||||||
|
wgListener6, err := net.ListenUDP("udp6", &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("::1"),
|
||||||
|
Port: wgPort,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create IPv6 WireGuard listener: %v", err)
|
||||||
|
}
|
||||||
|
defer wgListener6.Close()
|
||||||
|
|
||||||
|
// Determine which listener to use based on the NetBird address IP version
|
||||||
|
// (this is where initial traffic will come from before RedirectAs is called)
|
||||||
|
var wgListener *net.UDPConn
|
||||||
|
if p2pEndpoint.IP.To4() == nil {
|
||||||
|
wgListener = wgListener6
|
||||||
|
} else {
|
||||||
|
wgListener = wgListener4
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create relay server and connection
|
||||||
|
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
Port: 0, // Random port
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create relay server: %v", err)
|
||||||
|
}
|
||||||
|
defer relayServer.Close()
|
||||||
|
|
||||||
|
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create relay connection: %v", err)
|
||||||
|
}
|
||||||
|
defer relayConn.Close()
|
||||||
|
|
||||||
|
// Add TURN connection to proxy
|
||||||
|
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
|
||||||
|
t.Fatalf("failed to add TURN connection: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := proxy.CloseConn(); err != nil {
|
||||||
|
t.Errorf("failed to close proxy connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Start the proxy
|
||||||
|
proxy.Work()
|
||||||
|
|
||||||
|
// Phase 1: Test initial relay traffic
|
||||||
|
msgFromRelay := []byte("hello from relay")
|
||||||
|
if _, err := relayServer.WriteTo(msgFromRelay, relayConn.LocalAddr()); err != nil {
|
||||||
|
t.Fatalf("failed to write to relay server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set read deadline to avoid hanging
|
||||||
|
if err := wgListener4.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||||
|
t.Fatalf("failed to set read deadline: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
n, _, err := wgListener4.ReadFrom(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read from WireGuard listener: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n != len(msgFromRelay) {
|
||||||
|
t.Errorf("expected %d bytes, got %d", len(msgFromRelay), n)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(buf[:n]) != string(msgFromRelay) {
|
||||||
|
t.Errorf("expected message %q, got %q", msgFromRelay, buf[:n])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 2: Redirect to p2p endpoint
|
||||||
|
proxy.RedirectAs(p2pEndpoint)
|
||||||
|
|
||||||
|
// Give the proxy a moment to process the redirect
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Phase 3: Test redirected traffic
|
||||||
|
redirectedMessages := [][]byte{
|
||||||
|
[]byte("redirected message 1"),
|
||||||
|
[]byte("redirected message 2"),
|
||||||
|
[]byte("redirected message 3"),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, msg := range redirectedMessages {
|
||||||
|
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
|
||||||
|
t.Fatalf("failed to write redirected message %d: %v", i+1, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||||
|
t.Fatalf("failed to set read deadline: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, srcAddr, err := wgListener.ReadFrom(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read redirected message %d: %v", i+1, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify message content
|
||||||
|
if string(buf[:n]) != string(msg) {
|
||||||
|
t.Errorf("message %d: expected %q, got %q", i+1, msg, buf[:n])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify source address matches p2p endpoint (this is the key test)
|
||||||
|
// Use compareUDPAddr to ignore IPv6 zone IDs
|
||||||
|
if !compareUDPAddr(srcAddr, p2pEndpoint) {
|
||||||
|
t.Errorf("message %d: expected source address %s, got %s",
|
||||||
|
i+1, p2pEndpoint.String(), srcAddr.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
|
||||||
|
func TestRedirectAs_Multiple_Switches(t *testing.T) {
|
||||||
|
wgPort := 51856
|
||||||
|
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||||
|
if err := ebpfProxy.Listen(); err != nil {
|
||||||
|
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := ebpfProxy.Free(); err != nil {
|
||||||
|
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create WireGuard listener
|
||||||
|
wgListener, err := net.ListenUDP("udp4", &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
Port: wgPort,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create WireGuard listener: %v", err)
|
||||||
|
}
|
||||||
|
defer wgListener.Close()
|
||||||
|
|
||||||
|
// Create relay server and connection
|
||||||
|
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
Port: 0,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create relay server: %v", err)
|
||||||
|
}
|
||||||
|
defer relayServer.Close()
|
||||||
|
|
||||||
|
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create relay connection: %v", err)
|
||||||
|
}
|
||||||
|
defer relayConn.Close()
|
||||||
|
|
||||||
|
nbAddr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("100.108.111.177"),
|
||||||
|
Port: 38746,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
|
||||||
|
t.Fatalf("failed to add TURN connection: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := proxy.CloseConn(); err != nil {
|
||||||
|
t.Errorf("failed to close proxy connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
proxy.Work()
|
||||||
|
|
||||||
|
// Test switching between multiple endpoints - using addresses in local subnet
|
||||||
|
endpoints := []*net.UDPAddr{
|
||||||
|
{IP: net.ParseIP("192.168.0.100"), Port: 51820},
|
||||||
|
{IP: net.ParseIP("192.168.0.101"), Port: 51821},
|
||||||
|
{IP: net.ParseIP("192.168.0.102"), Port: 51822},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, endpoint := range endpoints {
|
||||||
|
proxy.RedirectAs(endpoint)
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
msg := []byte("test message")
|
||||||
|
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
|
||||||
|
t.Fatalf("failed to write message for endpoint %d: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||||
|
t.Fatalf("failed to set read deadline: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, srcAddr, err := wgListener.ReadFrom(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read message for endpoint %d: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(buf[:n]) != string(msg) {
|
||||||
|
t.Errorf("endpoint %d: expected message %q, got %q", i, msg, buf[:n])
|
||||||
|
}
|
||||||
|
|
||||||
|
if !compareUDPAddr(srcAddr, endpoint) {
|
||||||
|
t.Errorf("endpoint %d: expected source %s, got %s",
|
||||||
|
i, endpoint.String(), srcAddr.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -56,7 +56,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
|
|||||||
// the connection is complete, an error is returned. Once successfully
|
// the connection is complete, an error is returned. Once successfully
|
||||||
// connected, any expiration of the context will not affect the
|
// connected, any expiration of the context will not affect the
|
||||||
// connection.
|
// connection.
|
||||||
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
|
||||||
dialer := net.Dialer{}
|
dialer := net.Dialer{}
|
||||||
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -19,37 +19,56 @@ var (
|
|||||||
FixLengths: true,
|
FixLengths: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
localHostNetIPAddr = &net.IPAddr{
|
localHostNetIPAddrV4 = &net.IPAddr{
|
||||||
IP: net.ParseIP("127.0.0.1"),
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
}
|
}
|
||||||
|
localHostNetIPAddrV6 = &net.IPAddr{
|
||||||
|
IP: net.ParseIP("::1"),
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
type SrcFaker struct {
|
type SrcFaker struct {
|
||||||
srcAddr *net.UDPAddr
|
srcAddr *net.UDPAddr
|
||||||
|
|
||||||
rawSocket net.PacketConn
|
rawSocket net.PacketConn
|
||||||
ipH gopacket.SerializableLayer
|
ipH gopacket.SerializableLayer
|
||||||
udpH gopacket.SerializableLayer
|
udpH gopacket.SerializableLayer
|
||||||
layerBuffer gopacket.SerializeBuffer
|
layerBuffer gopacket.SerializeBuffer
|
||||||
|
localHostAddr *net.IPAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
|
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
|
||||||
rawSocket, err := rawsocket.PrepareSenderRawSocket()
|
// Create only the raw socket for the address family we need
|
||||||
|
var rawSocket net.PacketConn
|
||||||
|
var err error
|
||||||
|
var localHostAddr *net.IPAddr
|
||||||
|
|
||||||
|
if srcAddr.IP.To4() != nil {
|
||||||
|
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv4()
|
||||||
|
localHostAddr = localHostNetIPAddrV4
|
||||||
|
} else {
|
||||||
|
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv6()
|
||||||
|
localHostAddr = localHostNetIPAddrV6
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
|
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if closeErr := rawSocket.Close(); closeErr != nil {
|
||||||
|
log.Warnf("failed to close raw socket: %v", closeErr)
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
f := &SrcFaker{
|
f := &SrcFaker{
|
||||||
srcAddr: srcAddr,
|
srcAddr: srcAddr,
|
||||||
rawSocket: rawSocket,
|
rawSocket: rawSocket,
|
||||||
ipH: ipH,
|
ipH: ipH,
|
||||||
udpH: udpH,
|
udpH: udpH,
|
||||||
layerBuffer: gopacket.NewSerializeBuffer(),
|
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||||
|
localHostAddr: localHostAddr,
|
||||||
}
|
}
|
||||||
|
|
||||||
return f, nil
|
return f, nil
|
||||||
@@ -72,7 +91,7 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("serialize layers: %w", err)
|
return 0, fmt.Errorf("serialize layers: %w", err)
|
||||||
}
|
}
|
||||||
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
|
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), f.localHostAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("write to raw conn: %w", err)
|
return 0, fmt.Errorf("write to raw conn: %w", err)
|
||||||
}
|
}
|
||||||
@@ -80,19 +99,40 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
|
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
|
||||||
ipH := &layers.IPv4{
|
var ipH gopacket.SerializableLayer
|
||||||
DstIP: net.ParseIP("127.0.0.1"),
|
var networkLayer gopacket.NetworkLayer
|
||||||
SrcIP: srcAddr.IP,
|
|
||||||
Version: 4,
|
// Check if source IP is IPv4 or IPv6
|
||||||
TTL: 64,
|
if srcAddr.IP.To4() != nil {
|
||||||
Protocol: layers.IPProtocolUDP,
|
// IPv4
|
||||||
|
ipv4 := &layers.IPv4{
|
||||||
|
DstIP: localHostNetIPAddrV4.IP,
|
||||||
|
SrcIP: srcAddr.IP,
|
||||||
|
Version: 4,
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: layers.IPProtocolUDP,
|
||||||
|
}
|
||||||
|
ipH = ipv4
|
||||||
|
networkLayer = ipv4
|
||||||
|
} else {
|
||||||
|
// IPv6
|
||||||
|
ipv6 := &layers.IPv6{
|
||||||
|
DstIP: localHostNetIPAddrV6.IP,
|
||||||
|
SrcIP: srcAddr.IP,
|
||||||
|
Version: 6,
|
||||||
|
HopLimit: 64,
|
||||||
|
NextHeader: layers.IPProtocolUDP,
|
||||||
|
}
|
||||||
|
ipH = ipv6
|
||||||
|
networkLayer = ipv6
|
||||||
}
|
}
|
||||||
|
|
||||||
udpH := &layers.UDP{
|
udpH := &layers.UDP{
|
||||||
SrcPort: layers.UDPPort(srcAddr.Port),
|
SrcPort: layers.UDPPort(srcAddr.Port),
|
||||||
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
|
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
|
||||||
}
|
}
|
||||||
|
|
||||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
err := udpH.SetNetworkLayerForChecksum(networkLayer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
|
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
499
client/internal/auth/auth.go
Normal file
499
client/internal/auth/auth.go
Normal file
@@ -0,0 +1,499 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Auth manages authentication operations with the management server
|
||||||
|
// It maintains a long-lived connection and automatically handles reconnection with backoff
|
||||||
|
type Auth struct {
|
||||||
|
mutex sync.RWMutex
|
||||||
|
client *mgm.GrpcClient
|
||||||
|
config *profilemanager.Config
|
||||||
|
privateKey wgtypes.Key
|
||||||
|
mgmURL *url.URL
|
||||||
|
mgmTLSEnabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuth creates a new Auth instance that manages authentication flows
|
||||||
|
// It establishes a connection to the management server that will be reused for all operations
|
||||||
|
// The connection is automatically recreated with backoff if it becomes disconnected
|
||||||
|
func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) {
|
||||||
|
// Validate WireGuard private key
|
||||||
|
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine TLS setting based on URL scheme
|
||||||
|
mgmTLSEnabled := mgmURL.Scheme == "https"
|
||||||
|
|
||||||
|
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||||
|
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed connecting to Management Service %s: %v", mgmURL.String(), err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||||
|
|
||||||
|
return &Auth{
|
||||||
|
client: mgmClient,
|
||||||
|
config: config,
|
||||||
|
privateKey: myPrivateKey,
|
||||||
|
mgmURL: mgmURL,
|
||||||
|
mgmTLSEnabled: mgmTLSEnabled,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the management client connection
|
||||||
|
func (a *Auth) Close() error {
|
||||||
|
a.mutex.Lock()
|
||||||
|
defer a.mutex.Unlock()
|
||||||
|
|
||||||
|
if a.client == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return a.client.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSSOSupported checks if the management server supports SSO by attempting to retrieve auth flow configurations.
|
||||||
|
// Returns true if either PKCE or Device authorization flow is supported, false otherwise.
|
||||||
|
// This function encapsulates the SSO detection logic to avoid exposing gRPC error codes to upper layers.
|
||||||
|
// Automatically retries with backoff and reconnection on connection errors.
|
||||||
|
func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) {
|
||||||
|
var supportsSSO bool
|
||||||
|
|
||||||
|
err := a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||||
|
// Try PKCE flow first
|
||||||
|
_, err := a.getPKCEFlow(client)
|
||||||
|
if err == nil {
|
||||||
|
supportsSSO = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if PKCE is not supported
|
||||||
|
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||||
|
// PKCE not supported, try Device flow
|
||||||
|
_, err = a.getDeviceFlow(client)
|
||||||
|
if err == nil {
|
||||||
|
supportsSSO = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if Device flow is also not supported
|
||||||
|
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||||
|
// Neither PKCE nor Device flow is supported
|
||||||
|
supportsSSO = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Device flow check returned an error other than NotFound/Unimplemented
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// PKCE flow check returned an error other than NotFound/Unimplemented
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
return supportsSSO, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOAuthFlow returns an OAuth flow (PKCE or Device) using the existing management connection
|
||||||
|
// This avoids creating a new connection to the management server
|
||||||
|
func (a *Auth) GetOAuthFlow(ctx context.Context, forceDeviceAuth bool) (OAuthFlow, error) {
|
||||||
|
var flow OAuthFlow
|
||||||
|
var err error
|
||||||
|
|
||||||
|
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||||
|
if forceDeviceAuth {
|
||||||
|
flow, err = a.getDeviceFlow(client)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try PKCE flow first
|
||||||
|
flow, err = a.getPKCEFlow(client)
|
||||||
|
if err != nil {
|
||||||
|
// If PKCE not supported, try Device flow
|
||||||
|
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||||
|
flow, err = a.getDeviceFlow(client)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return flow, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLoginRequired checks if login is required by attempting to authenticate with the server
|
||||||
|
// Automatically retries with backoff and reconnection on connection errors.
|
||||||
|
func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
|
||||||
|
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var needsLogin bool
|
||||||
|
|
||||||
|
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||||
|
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||||
|
if isLoginNeeded(err) {
|
||||||
|
needsLogin = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
needsLogin = false
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
return needsLogin, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login attempts to log in or register the client with the management server
|
||||||
|
// Returns error and a boolean indicating if it's an authentication error (permission denied) that should stop retries.
|
||||||
|
// Automatically retries with backoff and reconnection on connection errors.
|
||||||
|
func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (error, bool) {
|
||||||
|
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
||||||
|
if err != nil {
|
||||||
|
return err, false
|
||||||
|
}
|
||||||
|
|
||||||
|
var isAuthError bool
|
||||||
|
|
||||||
|
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||||
|
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||||
|
if serverKey != nil && isRegistrationNeeded(err) {
|
||||||
|
log.Debugf("peer registration required")
|
||||||
|
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
||||||
|
if err != nil {
|
||||||
|
isAuthError = isPermissionDenied(err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
isAuthError = isPermissionDenied(err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
isAuthError = false
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return err, isAuthError
|
||||||
|
}
|
||||||
|
|
||||||
|
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
||||||
|
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
||||||
|
serverKey, err := client.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
|
||||||
|
if err != nil {
|
||||||
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
|
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
log.Errorf("failed to retrieve pkce flow: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
protoConfig := protoFlow.GetProviderConfig()
|
||||||
|
config := &PKCEAuthProviderConfig{
|
||||||
|
Audience: protoConfig.GetAudience(),
|
||||||
|
ClientID: protoConfig.GetClientID(),
|
||||||
|
ClientSecret: protoConfig.GetClientSecret(),
|
||||||
|
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||||
|
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
|
||||||
|
Scope: protoConfig.GetScope(),
|
||||||
|
RedirectURLs: protoConfig.GetRedirectURLs(),
|
||||||
|
UseIDToken: protoConfig.GetUseIDToken(),
|
||||||
|
ClientCertPair: a.config.ClientCertKeyPair,
|
||||||
|
DisablePromptLogin: protoConfig.GetDisablePromptLogin(),
|
||||||
|
LoginFlag: common.LoginFlag(protoConfig.GetLoginFlag()),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validatePKCEConfig(config); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
flow, err := NewPKCEAuthorizationFlow(*config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return flow, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
||||||
|
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
||||||
|
serverKey, err := client.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
|
||||||
|
if err != nil {
|
||||||
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
|
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
log.Errorf("failed to retrieve device flow: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
protoConfig := protoFlow.GetProviderConfig()
|
||||||
|
config := &DeviceAuthProviderConfig{
|
||||||
|
Audience: protoConfig.GetAudience(),
|
||||||
|
ClientID: protoConfig.GetClientID(),
|
||||||
|
ClientSecret: protoConfig.GetClientSecret(),
|
||||||
|
Domain: protoConfig.Domain,
|
||||||
|
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||||
|
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
|
||||||
|
Scope: protoConfig.GetScope(),
|
||||||
|
UseIDToken: protoConfig.GetUseIDToken(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep compatibility with older management versions
|
||||||
|
if config.Scope == "" {
|
||||||
|
config.Scope = "openid"
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validateDeviceAuthConfig(config); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
flow, err := NewDeviceAuthorizationFlow(*config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return flow, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// doMgmLogin performs the actual login operation with the management service
|
||||||
|
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
||||||
|
serverKey, err := client.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sysInfo := system.GetInfo(ctx)
|
||||||
|
a.setSystemInfoFlags(sysInfo)
|
||||||
|
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||||
|
return serverKey, loginResp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||||
|
// Otherwise tries to register with the provided setupKey via command line.
|
||||||
|
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||||
|
serverPublicKey, err := client.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
validSetupKey, err := uuid.Parse(setupKey)
|
||||||
|
if err != nil && jwtToken == "" {
|
||||||
|
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("sending peer registration request to Management Service")
|
||||||
|
info := system.GetInfo(ctx)
|
||||||
|
a.setSystemInfoFlags(info)
|
||||||
|
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed registering peer %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("peer has been successfully registered on Management Service")
|
||||||
|
|
||||||
|
return loginResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setSystemInfoFlags sets all configuration flags on the provided system info
|
||||||
|
func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
||||||
|
info.SetFlags(
|
||||||
|
a.config.RosenpassEnabled,
|
||||||
|
a.config.RosenpassPermissive,
|
||||||
|
a.config.ServerSSHAllowed,
|
||||||
|
a.config.DisableClientRoutes,
|
||||||
|
a.config.DisableServerRoutes,
|
||||||
|
a.config.DisableDNS,
|
||||||
|
a.config.DisableFirewall,
|
||||||
|
a.config.BlockLANAccess,
|
||||||
|
a.config.BlockInbound,
|
||||||
|
a.config.LazyConnectionEnabled,
|
||||||
|
a.config.EnableSSHRoot,
|
||||||
|
a.config.EnableSSHSFTP,
|
||||||
|
a.config.EnableSSHLocalPortForwarding,
|
||||||
|
a.config.EnableSSHRemotePortForwarding,
|
||||||
|
a.config.DisableSSHAuth,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// reconnect closes the current connection and creates a new one
|
||||||
|
// It checks if the brokenClient is still the current client before reconnecting
|
||||||
|
// to avoid multiple threads reconnecting unnecessarily
|
||||||
|
func (a *Auth) reconnect(ctx context.Context, brokenClient *mgm.GrpcClient) error {
|
||||||
|
a.mutex.Lock()
|
||||||
|
defer a.mutex.Unlock()
|
||||||
|
|
||||||
|
// Double-check: if client has already been replaced by another thread, skip reconnection
|
||||||
|
if a.client != brokenClient {
|
||||||
|
log.Debugf("client already reconnected by another thread, skipping")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new connection FIRST, before closing the old one
|
||||||
|
// This ensures a.client is never nil, preventing panics in other threads
|
||||||
|
log.Debugf("reconnecting to Management Service %s", a.mgmURL.String())
|
||||||
|
mgmClient, err := mgm.NewClient(ctx, a.mgmURL.Host, a.privateKey, a.mgmTLSEnabled)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed reconnecting to Management Service %s: %v", a.mgmURL.String(), err)
|
||||||
|
// Keep the old client if reconnection fails
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close old connection AFTER new one is successfully created
|
||||||
|
oldClient := a.client
|
||||||
|
a.client = mgmClient
|
||||||
|
|
||||||
|
if oldClient != nil {
|
||||||
|
if err := oldClient.Close(); err != nil {
|
||||||
|
log.Debugf("error closing old connection: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("successfully reconnected to Management service %s", a.mgmURL.String())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isConnectionError checks if the error is a connection-related error that should trigger reconnection
|
||||||
|
func isConnectionError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// These error codes indicate connection issues
|
||||||
|
return s.Code() == codes.Unavailable ||
|
||||||
|
s.Code() == codes.DeadlineExceeded ||
|
||||||
|
s.Code() == codes.Canceled ||
|
||||||
|
s.Code() == codes.Internal
|
||||||
|
}
|
||||||
|
|
||||||
|
// withRetry wraps an operation with exponential backoff retry logic
|
||||||
|
// It automatically reconnects on connection errors
|
||||||
|
func (a *Auth) withRetry(ctx context.Context, operation func(client *mgm.GrpcClient) error) error {
|
||||||
|
backoffSettings := &backoff.ExponentialBackOff{
|
||||||
|
InitialInterval: 500 * time.Millisecond,
|
||||||
|
RandomizationFactor: 0.5,
|
||||||
|
Multiplier: 1.5,
|
||||||
|
MaxInterval: 10 * time.Second,
|
||||||
|
MaxElapsedTime: 2 * time.Minute,
|
||||||
|
Stop: backoff.Stop,
|
||||||
|
Clock: backoff.SystemClock,
|
||||||
|
}
|
||||||
|
backoffSettings.Reset()
|
||||||
|
|
||||||
|
return backoff.RetryNotify(
|
||||||
|
func() error {
|
||||||
|
// Capture the client BEFORE the operation to ensure we track the correct client
|
||||||
|
a.mutex.RLock()
|
||||||
|
currentClient := a.client
|
||||||
|
a.mutex.RUnlock()
|
||||||
|
|
||||||
|
if currentClient == nil {
|
||||||
|
return status.Errorf(codes.Unavailable, "client is not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute operation with the captured client
|
||||||
|
err := operation(currentClient)
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it's a connection error, attempt reconnection using the client that was actually used
|
||||||
|
if isConnectionError(err) {
|
||||||
|
log.Warnf("connection error detected, attempting reconnection: %v", err)
|
||||||
|
|
||||||
|
if reconnectErr := a.reconnect(ctx, currentClient); reconnectErr != nil {
|
||||||
|
log.Errorf("reconnection failed: %v", reconnectErr)
|
||||||
|
return reconnectErr
|
||||||
|
}
|
||||||
|
// Return the original error to trigger retry with the new connection
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// For authentication errors, don't retry
|
||||||
|
if isAuthenticationError(err) {
|
||||||
|
return backoff.Permanent(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
backoff.WithContext(backoffSettings, ctx),
|
||||||
|
func(err error, duration time.Duration) {
|
||||||
|
log.Warnf("operation failed, retrying in %v: %v", duration, err)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAuthenticationError checks if the error is an authentication-related error that should not be retried.
|
||||||
|
// Returns true if the error is InvalidArgument or PermissionDenied, indicating that retrying won't help.
|
||||||
|
func isAuthenticationError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPermissionDenied checks if the error is a PermissionDenied error.
|
||||||
|
// This is used to determine if early exit from backoff is needed (e.g., when the server responded but denied access).
|
||||||
|
func isPermissionDenied(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.Code() == codes.PermissionDenied
|
||||||
|
}
|
||||||
|
|
||||||
|
func isLoginNeeded(err error) bool {
|
||||||
|
return isAuthenticationError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRegistrationNeeded(err error) bool {
|
||||||
|
return isPermissionDenied(err)
|
||||||
|
}
|
||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -26,12 +25,56 @@ const (
|
|||||||
|
|
||||||
var _ OAuthFlow = &DeviceAuthorizationFlow{}
|
var _ OAuthFlow = &DeviceAuthorizationFlow{}
|
||||||
|
|
||||||
|
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
||||||
|
type DeviceAuthProviderConfig struct {
|
||||||
|
// ClientID An IDP application client id
|
||||||
|
ClientID string
|
||||||
|
// ClientSecret An IDP application client secret
|
||||||
|
ClientSecret string
|
||||||
|
// Domain An IDP API domain
|
||||||
|
// Deprecated. Use OIDCConfigEndpoint instead
|
||||||
|
Domain string
|
||||||
|
// Audience An Audience for to authorization validation
|
||||||
|
Audience string
|
||||||
|
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||||
|
TokenEndpoint string
|
||||||
|
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||||
|
DeviceAuthEndpoint string
|
||||||
|
// Scopes provides the scopes to be included in the token request
|
||||||
|
Scope string
|
||||||
|
// UseIDToken indicates if the id token should be used for authentication
|
||||||
|
UseIDToken bool
|
||||||
|
// LoginHint is used to pre-fill the email/username field during authentication
|
||||||
|
LoginHint string
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateDeviceAuthConfig validates device authorization provider configuration
|
||||||
|
func validateDeviceAuthConfig(config *DeviceAuthProviderConfig) error {
|
||||||
|
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||||
|
|
||||||
|
if config.Audience == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Audience")
|
||||||
|
}
|
||||||
|
if config.ClientID == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Client ID")
|
||||||
|
}
|
||||||
|
if config.TokenEndpoint == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
|
||||||
|
}
|
||||||
|
if config.DeviceAuthEndpoint == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Device Auth Endpoint")
|
||||||
|
}
|
||||||
|
if config.Scope == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Device Auth Scopes")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// DeviceAuthorizationFlow implements the OAuthFlow interface,
|
// DeviceAuthorizationFlow implements the OAuthFlow interface,
|
||||||
// for the Device Authorization Flow.
|
// for the Device Authorization Flow.
|
||||||
type DeviceAuthorizationFlow struct {
|
type DeviceAuthorizationFlow struct {
|
||||||
providerConfig internal.DeviceAuthProviderConfig
|
providerConfig DeviceAuthProviderConfig
|
||||||
|
HTTPClient HTTPClient
|
||||||
HTTPClient HTTPClient
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequestDeviceCodePayload used for request device code payload for auth0
|
// RequestDeviceCodePayload used for request device code payload for auth0
|
||||||
@@ -57,7 +100,7 @@ type TokenRequestResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewDeviceAuthorizationFlow returns device authorization flow client
|
// NewDeviceAuthorizationFlow returns device authorization flow client
|
||||||
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
func NewDeviceAuthorizationFlow(config DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
httpTransport.MaxIdleConns = 5
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
@@ -89,6 +132,11 @@ func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
|
|||||||
return d.providerConfig.ClientID
|
return d.providerConfig.ClientID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLoginHint sets the login hint for the device authorization flow
|
||||||
|
func (d *DeviceAuthorizationFlow) SetLoginHint(hint string) {
|
||||||
|
d.providerConfig.LoginHint = hint
|
||||||
|
}
|
||||||
|
|
||||||
// RequestAuthInfo requests a device code login flow information from Hosted
|
// RequestAuthInfo requests a device code login flow information from Hosted
|
||||||
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
||||||
form := url.Values{}
|
form := url.Values{}
|
||||||
@@ -199,14 +247,22 @@ func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestR
|
|||||||
}
|
}
|
||||||
|
|
||||||
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
||||||
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
// it retrieves the access token from Hosted's endpoint and validates it before returning.
|
||||||
|
// The method creates a timeout context internally based on info.ExpiresIn.
|
||||||
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||||
|
// Create timeout context based on flow expiration
|
||||||
|
timeout := time.Duration(info.ExpiresIn) * time.Second
|
||||||
|
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
interval := time.Duration(info.Interval) * time.Second
|
interval := time.Duration(info.Interval) * time.Second
|
||||||
ticker := time.NewTicker(interval)
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-waitCtx.Done():
|
||||||
return TokenInfo{}, ctx.Err()
|
return TokenInfo{}, waitCtx.Err()
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
|
|
||||||
tokenResponse, err := d.requestToken(info)
|
tokenResponse, err := d.requestToken(info)
|
||||||
|
|||||||
@@ -12,8 +12,6 @@ import (
|
|||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockHTTPClient struct {
|
type mockHTTPClient struct {
|
||||||
@@ -115,18 +113,19 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
|||||||
err: testCase.inputReqError,
|
err: testCase.inputReqError,
|
||||||
}
|
}
|
||||||
|
|
||||||
deviceFlow := &DeviceAuthorizationFlow{
|
config := DeviceAuthProviderConfig{
|
||||||
providerConfig: internal.DeviceAuthProviderConfig{
|
Audience: expectedAudience,
|
||||||
Audience: expectedAudience,
|
ClientID: expectedClientID,
|
||||||
ClientID: expectedClientID,
|
Scope: expectedScope,
|
||||||
Scope: expectedScope,
|
TokenEndpoint: "test.hosted.com/token",
|
||||||
TokenEndpoint: "test.hosted.com/token",
|
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
UseIDToken: false,
|
||||||
UseIDToken: false,
|
|
||||||
},
|
|
||||||
HTTPClient: &httpClient,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
deviceFlow, err := NewDeviceAuthorizationFlow(config)
|
||||||
|
require.NoError(t, err, "creating device flow should not fail")
|
||||||
|
deviceFlow.HTTPClient = &httpClient
|
||||||
|
|
||||||
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
|
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
|
||||||
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
||||||
|
|
||||||
@@ -280,18 +279,19 @@ func TestHosted_WaitToken(t *testing.T) {
|
|||||||
countResBody: testCase.inputCountResBody,
|
countResBody: testCase.inputCountResBody,
|
||||||
}
|
}
|
||||||
|
|
||||||
deviceFlow := DeviceAuthorizationFlow{
|
config := DeviceAuthProviderConfig{
|
||||||
providerConfig: internal.DeviceAuthProviderConfig{
|
Audience: testCase.inputAudience,
|
||||||
Audience: testCase.inputAudience,
|
ClientID: clientID,
|
||||||
ClientID: clientID,
|
TokenEndpoint: "test.hosted.com/token",
|
||||||
TokenEndpoint: "test.hosted.com/token",
|
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
Scope: "openid",
|
||||||
Scope: "openid",
|
UseIDToken: false,
|
||||||
UseIDToken: false,
|
|
||||||
},
|
|
||||||
HTTPClient: &httpClient,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
deviceFlow, err := NewDeviceAuthorizationFlow(config)
|
||||||
|
require.NoError(t, err, "creating device flow should not fail")
|
||||||
|
deviceFlow.HTTPClient = &httpClient
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)
|
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -87,19 +86,33 @@ func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesk
|
|||||||
|
|
||||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
pkceFlowInfo, err := authClient.getPKCEFlow(authClient.client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pkceFlowInfo.ProviderConfig.LoginHint = hint
|
if hint != "" {
|
||||||
|
pkceFlowInfo.SetLoginHint(hint)
|
||||||
|
}
|
||||||
|
|
||||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
return pkceFlowInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
deviceFlowInfo, err := authClient.getDeviceFlow(authClient.client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch s, ok := gstatus.FromError(err); {
|
switch s, ok := gstatus.FromError(err); {
|
||||||
case ok && s.Code() == codes.NotFound:
|
case ok && s.Code() == codes.NotFound:
|
||||||
@@ -114,7 +127,9 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
deviceFlowInfo.ProviderConfig.LoginHint = hint
|
if hint != "" {
|
||||||
|
deviceFlowInfo.SetLoginHint(hint)
|
||||||
|
}
|
||||||
|
|
||||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
return deviceFlowInfo, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/templates"
|
"github.com/netbirdio/netbird/client/internal/templates"
|
||||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||||
)
|
)
|
||||||
@@ -35,17 +34,67 @@ const (
|
|||||||
defaultPKCETimeoutSeconds = 300
|
defaultPKCETimeoutSeconds = 300
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PKCEAuthProviderConfig has all attributes needed to initiate PKCE authorization flow
|
||||||
|
type PKCEAuthProviderConfig struct {
|
||||||
|
// ClientID An IDP application client id
|
||||||
|
ClientID string
|
||||||
|
// ClientSecret An IDP application client secret
|
||||||
|
ClientSecret string
|
||||||
|
// Audience An Audience for to authorization validation
|
||||||
|
Audience string
|
||||||
|
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||||
|
TokenEndpoint string
|
||||||
|
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
||||||
|
AuthorizationEndpoint string
|
||||||
|
// Scopes provides the scopes to be included in the token request
|
||||||
|
Scope string
|
||||||
|
// RedirectURL handles authorization code from IDP manager
|
||||||
|
RedirectURLs []string
|
||||||
|
// UseIDToken indicates if the id token should be used for authentication
|
||||||
|
UseIDToken bool
|
||||||
|
// ClientCertPair is used for mTLS authentication to the IDP
|
||||||
|
ClientCertPair *tls.Certificate
|
||||||
|
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
||||||
|
DisablePromptLogin bool
|
||||||
|
// LoginFlag is used to configure the PKCE flow login behavior
|
||||||
|
LoginFlag common.LoginFlag
|
||||||
|
// LoginHint is used to pre-fill the email/username field during authentication
|
||||||
|
LoginHint string
|
||||||
|
}
|
||||||
|
|
||||||
|
// validatePKCEConfig validates PKCE provider configuration
|
||||||
|
func validatePKCEConfig(config *PKCEAuthProviderConfig) error {
|
||||||
|
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||||
|
|
||||||
|
if config.ClientID == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Client ID")
|
||||||
|
}
|
||||||
|
if config.TokenEndpoint == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
|
||||||
|
}
|
||||||
|
if config.AuthorizationEndpoint == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Authorization Auth Endpoint")
|
||||||
|
}
|
||||||
|
if config.Scope == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "PKCE Auth Scopes")
|
||||||
|
}
|
||||||
|
if config.RedirectURLs == nil {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "PKCE Redirect URLs")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// PKCEAuthorizationFlow implements the OAuthFlow interface for
|
// PKCEAuthorizationFlow implements the OAuthFlow interface for
|
||||||
// the Authorization Code Flow with PKCE.
|
// the Authorization Code Flow with PKCE.
|
||||||
type PKCEAuthorizationFlow struct {
|
type PKCEAuthorizationFlow struct {
|
||||||
providerConfig internal.PKCEAuthProviderConfig
|
providerConfig PKCEAuthProviderConfig
|
||||||
state string
|
state string
|
||||||
codeVerifier string
|
codeVerifier string
|
||||||
oAuthConfig *oauth2.Config
|
oAuthConfig *oauth2.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
|
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
|
||||||
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
func NewPKCEAuthorizationFlow(config PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||||
var availableRedirectURL string
|
var availableRedirectURL string
|
||||||
|
|
||||||
excludedRanges := getSystemExcludedPortRanges()
|
excludedRanges := getSystemExcludedPortRanges()
|
||||||
@@ -124,10 +173,21 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLoginHint sets the login hint for the PKCE authorization flow
|
||||||
|
func (p *PKCEAuthorizationFlow) SetLoginHint(hint string) {
|
||||||
|
p.providerConfig.LoginHint = hint
|
||||||
|
}
|
||||||
|
|
||||||
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
|
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
|
||||||
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
|
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
|
||||||
// Once the token is received, it is converted to TokenInfo and validated before returning.
|
// Once the token is received, it is converted to TokenInfo and validated before returning.
|
||||||
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
|
// The method creates a timeout context internally based on info.ExpiresIn.
|
||||||
|
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||||
|
// Create timeout context based on flow expiration
|
||||||
|
timeout := time.Duration(info.ExpiresIn) * time.Second
|
||||||
|
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
tokenChan := make(chan *oauth2.Token, 1)
|
tokenChan := make(chan *oauth2.Token, 1)
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
|
|
||||||
@@ -138,7 +198,7 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
|||||||
|
|
||||||
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
||||||
defer func() {
|
defer func() {
|
||||||
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||||
@@ -149,8 +209,8 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
|||||||
go p.startServer(server, tokenChan, errChan)
|
go p.startServer(server, tokenChan, errChan)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-waitCtx.Done():
|
||||||
return TokenInfo{}, ctx.Err()
|
return TokenInfo{}, waitCtx.Err()
|
||||||
case token := <-tokenChan:
|
case token := <-tokenChan:
|
||||||
return p.parseOAuthToken(token)
|
return p.parseOAuthToken(token)
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client/common"
|
mgm "github.com/netbirdio/netbird/shared/management/client/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -50,7 +49,7 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
config := internal.PKCEAuthProviderConfig{
|
config := PKCEAuthProviderConfig{
|
||||||
ClientID: "test-client-id",
|
ClientID: "test-client-id",
|
||||||
Audience: "test-audience",
|
Audience: "test-audience",
|
||||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||||
|
|||||||
@@ -9,8 +9,6 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseExcludedPortRanges(t *testing.T) {
|
func TestParseExcludedPortRanges(t *testing.T) {
|
||||||
@@ -95,7 +93,7 @@ func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
|
|||||||
|
|
||||||
availablePort := 65432
|
availablePort := 65432
|
||||||
|
|
||||||
config := internal.PKCEAuthProviderConfig{
|
config := PKCEAuthProviderConfig{
|
||||||
ClientID: "test-client-id",
|
ClientID: "test-client-id",
|
||||||
Audience: "test-audience",
|
Audience: "test-audience",
|
||||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||||
|
|||||||
@@ -59,7 +59,6 @@ func NewConnectClient(
|
|||||||
config *profilemanager.Config,
|
config *profilemanager.Config,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
doInitalAutoUpdate bool,
|
doInitalAutoUpdate bool,
|
||||||
|
|
||||||
) *ConnectClient {
|
) *ConnectClient {
|
||||||
return &ConnectClient{
|
return &ConnectClient{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
@@ -71,8 +70,8 @@ func NewConnectClient(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Run with main logic.
|
// Run with main logic.
|
||||||
func (c *ConnectClient) Run(runningChan chan struct{}) error {
|
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
|
||||||
return c.run(MobileDependency{}, runningChan)
|
return c.run(MobileDependency{}, runningChan, logPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunOnAndroid with main logic on mobile system
|
// RunOnAndroid with main logic on mobile system
|
||||||
@@ -93,7 +92,7 @@ func (c *ConnectClient) RunOnAndroid(
|
|||||||
DnsReadyListener: dnsReadyListener,
|
DnsReadyListener: dnsReadyListener,
|
||||||
StateFilePath: stateFilePath,
|
StateFilePath: stateFilePath,
|
||||||
}
|
}
|
||||||
return c.run(mobileDependency, nil)
|
return c.run(mobileDependency, nil, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnectClient) RunOniOS(
|
func (c *ConnectClient) RunOniOS(
|
||||||
@@ -111,10 +110,10 @@ func (c *ConnectClient) RunOniOS(
|
|||||||
DnsManager: dnsManager,
|
DnsManager: dnsManager,
|
||||||
StateFilePath: stateFilePath,
|
StateFilePath: stateFilePath,
|
||||||
}
|
}
|
||||||
return c.run(mobileDependency, nil)
|
return c.run(mobileDependency, nil, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error {
|
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
rec := c.statusRecorder
|
rec := c.statusRecorder
|
||||||
@@ -284,7 +283,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
relayURLs, token := parseRelayInfo(loginResp)
|
relayURLs, token := parseRelayInfo(loginResp)
|
||||||
peerConfig := loginResp.GetPeerConfig()
|
peerConfig := loginResp.GetPeerConfig()
|
||||||
|
|
||||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
|
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
@@ -472,7 +471,7 @@ func (c *ConnectClient) SetSyncResponsePersistence(enabled bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
||||||
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
|
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig, logPath string) (*EngineConfig, error) {
|
||||||
nm := false
|
nm := false
|
||||||
if config.NetworkMonitor != nil {
|
if config.NetworkMonitor != nil {
|
||||||
nm = *config.NetworkMonitor
|
nm = *config.NetworkMonitor
|
||||||
@@ -507,7 +506,10 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
|||||||
|
|
||||||
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
||||||
|
|
||||||
MTU: selectMTU(config.MTU, peerConfig.Mtu),
|
MTU: selectMTU(config.MTU, peerConfig.Mtu),
|
||||||
|
LogPath: logPath,
|
||||||
|
|
||||||
|
ProfileConfig: config,
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.PreSharedKey != "" {
|
if config.PreSharedKey != "" {
|
||||||
|
|||||||
@@ -28,8 +28,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||||
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
const readmeContent = `Netbird debug bundle
|
const readmeContent = `Netbird debug bundle
|
||||||
@@ -57,6 +59,7 @@ block.prof: Block profiling information.
|
|||||||
heap.prof: Heap profiling information (snapshot of memory allocations).
|
heap.prof: Heap profiling information (snapshot of memory allocations).
|
||||||
allocs.prof: Allocations profiling information.
|
allocs.prof: Allocations profiling information.
|
||||||
threadcreate.prof: Thread creation profiling information.
|
threadcreate.prof: Thread creation profiling information.
|
||||||
|
cpu.prof: CPU profiling information.
|
||||||
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
|
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
|
||||||
|
|
||||||
|
|
||||||
@@ -223,10 +226,11 @@ type BundleGenerator struct {
|
|||||||
internalConfig *profilemanager.Config
|
internalConfig *profilemanager.Config
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
syncResponse *mgmProto.SyncResponse
|
syncResponse *mgmProto.SyncResponse
|
||||||
logFile string
|
logPath string
|
||||||
|
cpuProfile []byte
|
||||||
|
refreshStatus func() // Optional callback to refresh status before bundle generation
|
||||||
|
|
||||||
anonymize bool
|
anonymize bool
|
||||||
clientStatus string
|
|
||||||
includeSystemInfo bool
|
includeSystemInfo bool
|
||||||
logFileCount uint32
|
logFileCount uint32
|
||||||
|
|
||||||
@@ -235,7 +239,6 @@ type BundleGenerator struct {
|
|||||||
|
|
||||||
type BundleConfig struct {
|
type BundleConfig struct {
|
||||||
Anonymize bool
|
Anonymize bool
|
||||||
ClientStatus string
|
|
||||||
IncludeSystemInfo bool
|
IncludeSystemInfo bool
|
||||||
LogFileCount uint32
|
LogFileCount uint32
|
||||||
}
|
}
|
||||||
@@ -244,7 +247,9 @@ type GeneratorDependencies struct {
|
|||||||
InternalConfig *profilemanager.Config
|
InternalConfig *profilemanager.Config
|
||||||
StatusRecorder *peer.Status
|
StatusRecorder *peer.Status
|
||||||
SyncResponse *mgmProto.SyncResponse
|
SyncResponse *mgmProto.SyncResponse
|
||||||
LogFile string
|
LogPath string
|
||||||
|
CPUProfile []byte
|
||||||
|
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
|
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
|
||||||
@@ -260,10 +265,11 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
|||||||
internalConfig: deps.InternalConfig,
|
internalConfig: deps.InternalConfig,
|
||||||
statusRecorder: deps.StatusRecorder,
|
statusRecorder: deps.StatusRecorder,
|
||||||
syncResponse: deps.SyncResponse,
|
syncResponse: deps.SyncResponse,
|
||||||
logFile: deps.LogFile,
|
logPath: deps.LogPath,
|
||||||
|
cpuProfile: deps.CPUProfile,
|
||||||
|
refreshStatus: deps.RefreshStatus,
|
||||||
|
|
||||||
anonymize: cfg.Anonymize,
|
anonymize: cfg.Anonymize,
|
||||||
clientStatus: cfg.ClientStatus,
|
|
||||||
includeSystemInfo: cfg.IncludeSystemInfo,
|
includeSystemInfo: cfg.IncludeSystemInfo,
|
||||||
logFileCount: logFileCount,
|
logFileCount: logFileCount,
|
||||||
}
|
}
|
||||||
@@ -309,13 +315,6 @@ func (g *BundleGenerator) createArchive() error {
|
|||||||
return fmt.Errorf("add status: %w", err)
|
return fmt.Errorf("add status: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if g.statusRecorder != nil {
|
|
||||||
status := g.statusRecorder.GetFullStatus()
|
|
||||||
seedFromStatus(g.anonymizer, &status)
|
|
||||||
} else {
|
|
||||||
log.Debugf("no status recorder available for seeding")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := g.addConfig(); err != nil {
|
if err := g.addConfig(); err != nil {
|
||||||
log.Errorf("failed to add config to debug bundle: %v", err)
|
log.Errorf("failed to add config to debug bundle: %v", err)
|
||||||
}
|
}
|
||||||
@@ -332,6 +331,10 @@ func (g *BundleGenerator) createArchive() error {
|
|||||||
log.Errorf("failed to add profiles to debug bundle: %v", err)
|
log.Errorf("failed to add profiles to debug bundle: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := g.addCPUProfile(); err != nil {
|
||||||
|
log.Errorf("failed to add CPU profile to debug bundle: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := g.addStackTrace(); err != nil {
|
if err := g.addStackTrace(); err != nil {
|
||||||
log.Errorf("failed to add stack trace to debug bundle: %v", err)
|
log.Errorf("failed to add stack trace to debug bundle: %v", err)
|
||||||
}
|
}
|
||||||
@@ -352,7 +355,7 @@ func (g *BundleGenerator) createArchive() error {
|
|||||||
log.Errorf("failed to add wg show output: %v", err)
|
log.Errorf("failed to add wg show output: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if g.logFile != "" && !slices.Contains(util.SpecialLogs, g.logFile) {
|
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
|
||||||
if err := g.addLogfile(); err != nil {
|
if err := g.addLogfile(); err != nil {
|
||||||
log.Errorf("failed to add log file to debug bundle: %v", err)
|
log.Errorf("failed to add log file to debug bundle: %v", err)
|
||||||
if err := g.trySystemdLogFallback(); err != nil {
|
if err := g.trySystemdLogFallback(); err != nil {
|
||||||
@@ -401,11 +404,30 @@ func (g *BundleGenerator) addReadme() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addStatus() error {
|
func (g *BundleGenerator) addStatus() error {
|
||||||
if status := g.clientStatus; status != "" {
|
if g.statusRecorder != nil {
|
||||||
statusReader := strings.NewReader(status)
|
pm := profilemanager.NewProfileManager()
|
||||||
|
var profName string
|
||||||
|
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||||
|
profName = activeProf.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
if g.refreshStatus != nil {
|
||||||
|
g.refreshStatus()
|
||||||
|
}
|
||||||
|
|
||||||
|
fullStatus := g.statusRecorder.GetFullStatus()
|
||||||
|
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
|
||||||
|
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
|
||||||
|
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, g.anonymize, version.NetbirdVersion(), "", nil, nil, nil, "", profName)
|
||||||
|
statusOutput := overview.FullDetailSummary()
|
||||||
|
|
||||||
|
statusReader := strings.NewReader(statusOutput)
|
||||||
if err := g.addFileToZip(statusReader, "status.txt"); err != nil {
|
if err := g.addFileToZip(statusReader, "status.txt"); err != nil {
|
||||||
return fmt.Errorf("add status file to zip: %w", err)
|
return fmt.Errorf("add status file to zip: %w", err)
|
||||||
}
|
}
|
||||||
|
seedFromStatus(g.anonymizer, &fullStatus)
|
||||||
|
} else {
|
||||||
|
log.Debugf("no status recorder available for seeding")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -535,6 +557,19 @@ func (g *BundleGenerator) addProf() (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addCPUProfile() error {
|
||||||
|
if len(g.cpuProfile) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := bytes.NewReader(g.cpuProfile)
|
||||||
|
if err := g.addFileToZip(reader, "cpu.prof"); err != nil {
|
||||||
|
return fmt.Errorf("add CPU profile to zip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addStackTrace() error {
|
func (g *BundleGenerator) addStackTrace() error {
|
||||||
buf := make([]byte, 5242880) // 5 MB buffer
|
buf := make([]byte, 5242880) // 5 MB buffer
|
||||||
n := runtime.Stack(buf, true)
|
n := runtime.Stack(buf, true)
|
||||||
@@ -710,14 +745,14 @@ func (g *BundleGenerator) addCorruptedStateFiles() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addLogfile() error {
|
func (g *BundleGenerator) addLogfile() error {
|
||||||
if g.logFile == "" {
|
if g.logPath == "" {
|
||||||
log.Debugf("skipping empty log file in debug bundle")
|
log.Debugf("skipping empty log file in debug bundle")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
logDir := filepath.Dir(g.logFile)
|
logDir := filepath.Dir(g.logPath)
|
||||||
|
|
||||||
if err := g.addSingleLogfile(g.logFile, clientLogFile); err != nil {
|
if err := g.addSingleLogfile(g.logPath, clientLogFile); err != nil {
|
||||||
return fmt.Errorf("add client log file to zip: %w", err)
|
return fmt.Errorf("add client log file to zip: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
101
client/internal/debug/upload.go
Normal file
101
client/internal/debug/upload.go
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
package debug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/upload-server/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
const maxBundleUploadSize = 50 * 1024 * 1024
|
||||||
|
|
||||||
|
func UploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) {
|
||||||
|
response, err := getUploadURL(ctx, url, managementURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = upload(ctx, filePath, response)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return response.Key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error {
|
||||||
|
fileData, err := os.Open(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("open file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer fileData.Close()
|
||||||
|
|
||||||
|
stat, err := fileData.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("stat file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if stat.Size() > maxBundleUploadSize {
|
||||||
|
return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create PUT request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.ContentLength = stat.Size()
|
||||||
|
req.Header.Set("Content-Type", "application/octet-stream")
|
||||||
|
|
||||||
|
putResp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("upload failed: %v", err)
|
||||||
|
}
|
||||||
|
defer putResp.Body.Close()
|
||||||
|
|
||||||
|
if putResp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(putResp.Body)
|
||||||
|
return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) {
|
||||||
|
id := getURLHash(managementURL)
|
||||||
|
getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create GET request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue)
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(getReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get presigned URL: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
urlBytes, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read response body: %w", err)
|
||||||
|
}
|
||||||
|
var response types.GetURLResponse
|
||||||
|
if err := json.Unmarshal(urlBytes, &response); err != nil {
|
||||||
|
return nil, fmt.Errorf("unmarshal response: %w", err)
|
||||||
|
}
|
||||||
|
return &response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getURLHash(url string) string {
|
||||||
|
return fmt.Sprintf("%x", sha256.Sum256([]byte(url)))
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package server
|
package debug
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -38,7 +38,7 @@ func TestUpload(t *testing.T) {
|
|||||||
fileContent := []byte("test file content")
|
fileContent := []byte("test file content")
|
||||||
err := os.WriteFile(file, fileContent, 0640)
|
err := os.WriteFile(file, fileContent, 0640)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
key, err := uploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
|
key, err := UploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
id := getURLHash(testURL)
|
id := getURLHash(testURL)
|
||||||
require.Contains(t, key, id+"/")
|
require.Contains(t, key, id+"/")
|
||||||
@@ -60,7 +60,7 @@ func (g *BundleGenerator) toWGShowFormat(s *configurer.Stats) string {
|
|||||||
}
|
}
|
||||||
sb.WriteString(fmt.Sprintf(" latest handshake: %s\n", peer.LastHandshake.Format(time.RFC1123)))
|
sb.WriteString(fmt.Sprintf(" latest handshake: %s\n", peer.LastHandshake.Format(time.RFC1123)))
|
||||||
sb.WriteString(fmt.Sprintf(" transfer: %d B received, %d B sent\n", peer.RxBytes, peer.TxBytes))
|
sb.WriteString(fmt.Sprintf(" transfer: %d B received, %d B sent\n", peer.RxBytes, peer.TxBytes))
|
||||||
if peer.PresharedKey {
|
if peer.PresharedKey != [32]byte{} {
|
||||||
sb.WriteString(" preshared key: (hidden)\n")
|
sb.WriteString(" preshared key: (hidden)\n")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,136 +0,0 @@
|
|||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
|
||||||
type DeviceAuthorizationFlow struct {
|
|
||||||
Provider string
|
|
||||||
ProviderConfig DeviceAuthProviderConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
|
||||||
type DeviceAuthProviderConfig struct {
|
|
||||||
// ClientID An IDP application client id
|
|
||||||
ClientID string
|
|
||||||
// ClientSecret An IDP application client secret
|
|
||||||
ClientSecret string
|
|
||||||
// Domain An IDP API domain
|
|
||||||
// Deprecated. Use OIDCConfigEndpoint instead
|
|
||||||
Domain string
|
|
||||||
// Audience An Audience for to authorization validation
|
|
||||||
Audience string
|
|
||||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
|
||||||
TokenEndpoint string
|
|
||||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
|
||||||
DeviceAuthEndpoint string
|
|
||||||
// Scopes provides the scopes to be included in the token request
|
|
||||||
Scope string
|
|
||||||
// UseIDToken indicates if the id token should be used for authentication
|
|
||||||
UseIDToken bool
|
|
||||||
// LoginHint is used to pre-fill the email/username field during authentication
|
|
||||||
LoginHint string
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
|
||||||
func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) {
|
|
||||||
// validate our peer's Wireguard PRIVATE key
|
|
||||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var mgmTLSEnabled bool
|
|
||||||
if mgmURL.Scheme == "https" {
|
|
||||||
mgmTLSEnabled = true
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
|
||||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err = mgmClient.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to close the Management service client %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
serverKey, err := mgmClient.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey)
|
|
||||||
if err != nil {
|
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
|
||||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
log.Errorf("failed to retrieve device flow: %v", err)
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
|
||||||
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
|
||||||
|
|
||||||
ProviderConfig: DeviceAuthProviderConfig{
|
|
||||||
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
|
||||||
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
|
||||||
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
|
||||||
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
|
|
||||||
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
|
||||||
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
|
|
||||||
Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(),
|
|
||||||
UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// keep compatibility with older management versions
|
|
||||||
if deviceAuthorizationFlow.ProviderConfig.Scope == "" {
|
|
||||||
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
|
|
||||||
}
|
|
||||||
|
|
||||||
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
|
||||||
if err != nil {
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return deviceAuthorizationFlow, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
|
|
||||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
|
||||||
if config.Audience == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Audience")
|
|
||||||
}
|
|
||||||
if config.ClientID == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
|
||||||
}
|
|
||||||
if config.TokenEndpoint == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
|
||||||
}
|
|
||||||
if config.DeviceAuthEndpoint == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
|
|
||||||
}
|
|
||||||
if config.Scope == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Device Auth Scopes")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -112,6 +112,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
|||||||
matchSubdomains: false,
|
matchSubdomains: false,
|
||||||
shouldMatch: false,
|
shouldMatch: false,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "single letter TLD exact match",
|
||||||
|
handlerDomain: "example.x.",
|
||||||
|
queryDomain: "example.x.",
|
||||||
|
isWildcard: false,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single letter TLD subdomain match",
|
||||||
|
handlerDomain: "example.x.",
|
||||||
|
queryDomain: "sub.example.x.",
|
||||||
|
isWildcard: false,
|
||||||
|
matchSubdomains: true,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single letter TLD wildcard match",
|
||||||
|
handlerDomain: "*.example.x.",
|
||||||
|
queryDomain: "sub.example.x.",
|
||||||
|
isWildcard: true,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two letter domain labels",
|
||||||
|
handlerDomain: "a.b.",
|
||||||
|
queryDomain: "a.b.",
|
||||||
|
isWildcard: false,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single character domain",
|
||||||
|
handlerDomain: "x.",
|
||||||
|
queryDomain: "x.",
|
||||||
|
isWildcard: false,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single character domain with subdomain match",
|
||||||
|
handlerDomain: "x.",
|
||||||
|
queryDomain: "sub.x.",
|
||||||
|
isWildcard: false,
|
||||||
|
matchSubdomains: true,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|||||||
@@ -9,8 +9,10 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
@@ -38,6 +40,9 @@ const (
|
|||||||
type systemConfigurator struct {
|
type systemConfigurator struct {
|
||||||
createdKeys map[string]struct{}
|
createdKeys map[string]struct{}
|
||||||
systemDNSSettings SystemDNSSettings
|
systemDNSSettings SystemDNSSettings
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
origNameservers []netip.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager() (*systemConfigurator, error) {
|
func newHostManager() (*systemConfigurator, error) {
|
||||||
@@ -218,6 +223,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var dnsSettings SystemDNSSettings
|
var dnsSettings SystemDNSSettings
|
||||||
|
var serverAddresses []netip.Addr
|
||||||
inSearchDomainsArray := false
|
inSearchDomainsArray := false
|
||||||
inServerAddressesArray := false
|
inServerAddressesArray := false
|
||||||
|
|
||||||
@@ -244,9 +250,12 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
|||||||
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
|
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
|
||||||
} else if inServerAddressesArray {
|
} else if inServerAddressesArray {
|
||||||
address := strings.Split(line, " : ")[1]
|
address := strings.Split(line, " : ")[1]
|
||||||
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
|
if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() {
|
||||||
dnsSettings.ServerIP = ip.Unmap()
|
ip = ip.Unmap()
|
||||||
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
|
serverAddresses = append(serverAddresses, ip)
|
||||||
|
if !dnsSettings.ServerIP.IsValid() && ip.Is4() {
|
||||||
|
dnsSettings.ServerIP = ip
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -258,9 +267,19 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
|||||||
// default to 53 port
|
// default to 53 port
|
||||||
dnsSettings.ServerPort = DefaultPort
|
dnsSettings.ServerPort = DefaultPort
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.origNameservers = serverAddresses
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
return dnsSettings, nil
|
return dnsSettings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return slices.Clone(s.origNameservers)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
|
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
|
||||||
err := s.addDNSState(key, domains, ip, port, true)
|
err := s.addDNSState(key, domains, ip, port, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -109,3 +109,169 @@ func removeTestDNSKey(key string) error {
|
|||||||
_, err := cmd.CombinedOutput()
|
_, err := cmd.CombinedOutput()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetOriginalNameservers(t *testing.T) {
|
||||||
|
configurator := &systemConfigurator{
|
||||||
|
createdKeys: make(map[string]struct{}),
|
||||||
|
origNameservers: []netip.Addr{
|
||||||
|
netip.MustParseAddr("8.8.8.8"),
|
||||||
|
netip.MustParseAddr("1.1.1.1"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
servers := configurator.getOriginalNameservers()
|
||||||
|
assert.Len(t, servers, 2)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("8.8.8.8"), servers[0])
|
||||||
|
assert.Equal(t, netip.MustParseAddr("1.1.1.1"), servers[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOriginalNameserversFromSystem(t *testing.T) {
|
||||||
|
configurator := &systemConfigurator{
|
||||||
|
createdKeys: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := configurator.getSystemDNSSettings()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
servers := configurator.getOriginalNameservers()
|
||||||
|
|
||||||
|
require.NotEmpty(t, servers, "expected at least one DNS server from system configuration")
|
||||||
|
|
||||||
|
for _, server := range servers {
|
||||||
|
assert.True(t, server.IsValid(), "server address should be valid")
|
||||||
|
assert.False(t, server.IsUnspecified(), "server address should not be unspecified")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("found %d original nameservers: %v", len(servers), servers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
stateFile := filepath.Join(tmpDir, "state.json")
|
||||||
|
sm := statemanager.New(stateFile)
|
||||||
|
sm.RegisterState(&ShutdownState{})
|
||||||
|
sm.Start()
|
||||||
|
|
||||||
|
configurator := &systemConfigurator{
|
||||||
|
createdKeys: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||||
|
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||||
|
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||||
|
|
||||||
|
cleanup := func() {
|
||||||
|
_ = sm.Stop(context.Background())
|
||||||
|
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||||
|
_ = removeTestDNSKey(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return configurator, sm, cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOriginalNameserversNoTransition(t *testing.T) {
|
||||||
|
netbirdIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
routeAll bool
|
||||||
|
}{
|
||||||
|
{"routeall_false", false},
|
||||||
|
{"routeall_true", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
configurator, sm, cleanup := setupTestConfigurator(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
_, err := configurator.getSystemDNSSettings()
|
||||||
|
require.NoError(t, err)
|
||||||
|
initialServers := configurator.getOriginalNameservers()
|
||||||
|
t.Logf("Initial servers: %v", initialServers)
|
||||||
|
require.NotEmpty(t, initialServers)
|
||||||
|
|
||||||
|
for _, srv := range initialServers {
|
||||||
|
require.NotEqual(t, netbirdIP, srv, "initial servers should not contain NetBird IP")
|
||||||
|
}
|
||||||
|
|
||||||
|
config := HostDNSConfig{
|
||||||
|
ServerIP: netbirdIP,
|
||||||
|
ServerPort: 53,
|
||||||
|
RouteAll: tc.routeAll,
|
||||||
|
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 1; i <= 2; i++ {
|
||||||
|
err = configurator.applyDNSConfig(config, sm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
servers := configurator.getOriginalNameservers()
|
||||||
|
t.Logf("After apply %d (RouteAll=%v): %v", i, tc.routeAll, servers)
|
||||||
|
assert.Equal(t, initialServers, servers)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOriginalNameserversRouteAllTransition(t *testing.T) {
|
||||||
|
netbirdIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
initialRoute bool
|
||||||
|
}{
|
||||||
|
{"start_with_routeall_false", false},
|
||||||
|
{"start_with_routeall_true", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
configurator, sm, cleanup := setupTestConfigurator(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
_, err := configurator.getSystemDNSSettings()
|
||||||
|
require.NoError(t, err)
|
||||||
|
initialServers := configurator.getOriginalNameservers()
|
||||||
|
t.Logf("Initial servers: %v", initialServers)
|
||||||
|
require.NotEmpty(t, initialServers)
|
||||||
|
|
||||||
|
config := HostDNSConfig{
|
||||||
|
ServerIP: netbirdIP,
|
||||||
|
ServerPort: 53,
|
||||||
|
RouteAll: tc.initialRoute,
|
||||||
|
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
|
||||||
|
}
|
||||||
|
|
||||||
|
// First apply
|
||||||
|
err = configurator.applyDNSConfig(config, sm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
servers := configurator.getOriginalNameservers()
|
||||||
|
t.Logf("After first apply (RouteAll=%v): %v", tc.initialRoute, servers)
|
||||||
|
assert.Equal(t, initialServers, servers)
|
||||||
|
|
||||||
|
// Toggle RouteAll
|
||||||
|
config.RouteAll = !tc.initialRoute
|
||||||
|
err = configurator.applyDNSConfig(config, sm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
servers = configurator.getOriginalNameservers()
|
||||||
|
t.Logf("After toggle (RouteAll=%v): %v", config.RouteAll, servers)
|
||||||
|
assert.Equal(t, initialServers, servers)
|
||||||
|
|
||||||
|
// Toggle back
|
||||||
|
config.RouteAll = tc.initialRoute
|
||||||
|
err = configurator.applyDNSConfig(config, sm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
servers = configurator.getOriginalNameservers()
|
||||||
|
t.Logf("After toggle back (RouteAll=%v): %v", config.RouteAll, servers)
|
||||||
|
assert.Equal(t, initialServers, servers)
|
||||||
|
|
||||||
|
for _, srv := range servers {
|
||||||
|
assert.NotEqual(t, netbirdIP, srv, "servers should not contain NetBird IP")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -81,7 +81,10 @@ func (d *Resolver) ProbeAvailability() {}
|
|||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
logger := log.WithField("request_id", resutil.GetRequestID(w))
|
logger := log.WithFields(log.Fields{
|
||||||
|
"request_id": resutil.GetRequestID(w),
|
||||||
|
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||||
|
})
|
||||||
|
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
logger.Debug("received local resolver request with no question")
|
logger.Debug("received local resolver request with no question")
|
||||||
@@ -120,7 +123,7 @@ func (d *Resolver) determineRcode(question dns.Question, result lookupResult) in
|
|||||||
}
|
}
|
||||||
|
|
||||||
// No records found, but domain exists with different record types (NODATA)
|
// No records found, but domain exists with different record types (NODATA)
|
||||||
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
|
if d.hasRecordsForDomain(domain.Domain(question.Name), question.Qtype) {
|
||||||
return dns.RcodeSuccess
|
return dns.RcodeSuccess
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -164,11 +167,15 @@ func (d *Resolver) continueToNext(logger *log.Entry, w dns.ResponseWriter, r *dn
|
|||||||
}
|
}
|
||||||
|
|
||||||
// hasRecordsForDomain checks if any records exist for the given domain name regardless of type
|
// hasRecordsForDomain checks if any records exist for the given domain name regardless of type
|
||||||
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
|
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain, qType uint16) bool {
|
||||||
d.mu.RLock()
|
d.mu.RLock()
|
||||||
defer d.mu.RUnlock()
|
defer d.mu.RUnlock()
|
||||||
|
|
||||||
_, exists := d.domains[domainName]
|
_, exists := d.domains[domainName]
|
||||||
|
if !exists && supportsWildcard(qType) {
|
||||||
|
testWild := transformDomainToWildcard(string(domainName))
|
||||||
|
_, exists = d.domains[domain.Domain(testWild)]
|
||||||
|
}
|
||||||
return exists
|
return exists
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -195,6 +202,16 @@ type lookupResult struct {
|
|||||||
func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) lookupResult {
|
func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) lookupResult {
|
||||||
d.mu.RLock()
|
d.mu.RLock()
|
||||||
records, found := d.records[question]
|
records, found := d.records[question]
|
||||||
|
usingWildcard := false
|
||||||
|
wildQuestion := transformToWildcard(question)
|
||||||
|
// RFC 4592 section 2.2.1: wildcard only matches if the name does NOT exist in the zone.
|
||||||
|
// If the domain exists with any record type, return NODATA instead of wildcard match.
|
||||||
|
if !found && supportsWildcard(question.Qtype) {
|
||||||
|
if _, domainExists := d.domains[domain.Domain(question.Name)]; !domainExists {
|
||||||
|
records, found = d.records[wildQuestion]
|
||||||
|
usingWildcard = found
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !found {
|
if !found {
|
||||||
d.mu.RUnlock()
|
d.mu.RUnlock()
|
||||||
@@ -216,18 +233,53 @@ func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) looku
|
|||||||
// if there's more than one record, rotate them (round-robin)
|
// if there's more than one record, rotate them (round-robin)
|
||||||
if len(recordsCopy) > 1 {
|
if len(recordsCopy) > 1 {
|
||||||
d.mu.Lock()
|
d.mu.Lock()
|
||||||
records = d.records[question]
|
q := question
|
||||||
|
if usingWildcard {
|
||||||
|
q = wildQuestion
|
||||||
|
}
|
||||||
|
records = d.records[q]
|
||||||
if len(records) > 1 {
|
if len(records) > 1 {
|
||||||
first := records[0]
|
first := records[0]
|
||||||
records = append(records[1:], first)
|
records = append(records[1:], first)
|
||||||
d.records[question] = records
|
d.records[q] = records
|
||||||
}
|
}
|
||||||
d.mu.Unlock()
|
d.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if usingWildcard {
|
||||||
|
return responseFromWildRecords(question.Name, wildQuestion.Name, recordsCopy)
|
||||||
|
}
|
||||||
|
|
||||||
return lookupResult{records: recordsCopy, rcode: dns.RcodeSuccess}
|
return lookupResult{records: recordsCopy, rcode: dns.RcodeSuccess}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func transformToWildcard(question dns.Question) dns.Question {
|
||||||
|
wildQuestion := question
|
||||||
|
wildQuestion.Name = transformDomainToWildcard(wildQuestion.Name)
|
||||||
|
return wildQuestion
|
||||||
|
}
|
||||||
|
|
||||||
|
func transformDomainToWildcard(domain string) string {
|
||||||
|
s := strings.Split(domain, ".")
|
||||||
|
s[0] = "*"
|
||||||
|
return strings.Join(s, ".")
|
||||||
|
}
|
||||||
|
|
||||||
|
func supportsWildcard(queryType uint16) bool {
|
||||||
|
return queryType != dns.TypeNS && queryType != dns.TypeSOA
|
||||||
|
}
|
||||||
|
|
||||||
|
func responseFromWildRecords(originalName, wildName string, wildRecords []dns.RR) lookupResult {
|
||||||
|
records := make([]dns.RR, len(wildRecords))
|
||||||
|
for i, record := range wildRecords {
|
||||||
|
copiedRecord := dns.Copy(record)
|
||||||
|
copiedRecord.Header().Name = originalName
|
||||||
|
records[i] = copiedRecord
|
||||||
|
}
|
||||||
|
|
||||||
|
return lookupResult{records: records, rcode: dns.RcodeSuccess}
|
||||||
|
}
|
||||||
|
|
||||||
// lookupCNAMEChain follows a CNAME chain and returns the CNAME records along with
|
// lookupCNAMEChain follows a CNAME chain and returns the CNAME records along with
|
||||||
// the final resolved record of the requested type. This is required for musl libc
|
// the final resolved record of the requested type. This is required for musl libc
|
||||||
// compatibility, which expects the full answer chain rather than just the CNAME.
|
// compatibility, which expects the full answer chain rather than just the CNAME.
|
||||||
@@ -237,6 +289,13 @@ func (d *Resolver) lookupCNAMEChain(logger *log.Entry, cnameQuestion dns.Questio
|
|||||||
|
|
||||||
for range maxDepth {
|
for range maxDepth {
|
||||||
cnameRecords := d.getRecords(cnameQuestion)
|
cnameRecords := d.getRecords(cnameQuestion)
|
||||||
|
if len(cnameRecords) == 0 && supportsWildcard(targetType) {
|
||||||
|
wildQuestion := transformToWildcard(cnameQuestion)
|
||||||
|
if wildRecords := d.getRecords(wildQuestion); len(wildRecords) > 0 {
|
||||||
|
cnameRecords = responseFromWildRecords(cnameQuestion.Name, wildQuestion.Name, wildRecords).records
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if len(cnameRecords) == 0 {
|
if len(cnameRecords) == 0 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -303,7 +362,7 @@ func (d *Resolver) resolveCNAMETarget(logger *log.Entry, targetName string, targ
|
|||||||
}
|
}
|
||||||
|
|
||||||
// domain exists locally but not this record type (NODATA)
|
// domain exists locally but not this record type (NODATA)
|
||||||
if d.hasRecordsForDomain(domain.Domain(targetName)) {
|
if d.hasRecordsForDomain(domain.Domain(targetName), targetType) {
|
||||||
return lookupResult{rcode: dns.RcodeSuccess}
|
return lookupResult{rcode: dns.RcodeSuccess}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -615,7 +615,7 @@ func (s *DefaultServer) applyHostConfig() {
|
|||||||
s.registerFallback(config)
|
s.registerFallback(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
// registerFallback registers original nameservers as low-priority fallback handlers
|
// registerFallback registers original nameservers as low-priority fallback handlers.
|
||||||
func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
||||||
hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS)
|
hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -624,6 +624,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
|||||||
|
|
||||||
originalNameservers := hostMgrWithNS.getOriginalNameservers()
|
originalNameservers := hostMgrWithNS.getOriginalNameservers()
|
||||||
if len(originalNameservers) == 0 {
|
if len(originalNameservers) == 0 {
|
||||||
|
s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,15 +8,21 @@ import (
|
|||||||
|
|
||||||
type MockResponseWriter struct {
|
type MockResponseWriter struct {
|
||||||
WriteMsgFunc func(m *dns.Msg) error
|
WriteMsgFunc func(m *dns.Msg) error
|
||||||
|
lastResponse *dns.Msg
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
|
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
|
||||||
|
rw.lastResponse = m
|
||||||
if rw.WriteMsgFunc != nil {
|
if rw.WriteMsgFunc != nil {
|
||||||
return rw.WriteMsgFunc(m)
|
return rw.WriteMsgFunc(m)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rw *MockResponseWriter) GetLastResponse() *dns.Msg {
|
||||||
|
return rw.lastResponse
|
||||||
|
}
|
||||||
|
|
||||||
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
|
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
|
||||||
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
|
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||||
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||||
|
|||||||
@@ -71,6 +71,11 @@ type upstreamResolverBase struct {
|
|||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type upstreamFailure struct {
|
||||||
|
upstream netip.AddrPort
|
||||||
|
reason string
|
||||||
|
}
|
||||||
|
|
||||||
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
|
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
@@ -114,7 +119,10 @@ func (u *upstreamResolverBase) Stop() {
|
|||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
logger := log.WithField("request_id", resutil.GetRequestID(w))
|
logger := log.WithFields(log.Fields{
|
||||||
|
"request_id": resutil.GetRequestID(w),
|
||||||
|
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||||
|
})
|
||||||
|
|
||||||
u.prepareRequest(r)
|
u.prepareRequest(r)
|
||||||
|
|
||||||
@@ -123,11 +131,13 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.tryUpstreamServers(w, r, logger) {
|
ok, failures := u.tryUpstreamServers(w, r, logger)
|
||||||
return
|
if len(failures) > 0 {
|
||||||
|
u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
u.writeErrorResponse(w, r, logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
u.writeErrorResponse(w, r, logger)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
|
func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
|
||||||
@@ -136,7 +146,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool {
|
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
||||||
timeout := u.upstreamTimeout
|
timeout := u.upstreamTimeout
|
||||||
if len(u.upstreamServers) > 1 {
|
if len(u.upstreamServers) > 1 {
|
||||||
maxTotal := 5 * time.Second
|
maxTotal := 5 * time.Second
|
||||||
@@ -149,15 +159,19 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var failures []upstreamFailure
|
||||||
for _, upstream := range u.upstreamServers {
|
for _, upstream := range u.upstreamServers {
|
||||||
if u.queryUpstream(w, r, upstream, timeout, logger) {
|
if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil {
|
||||||
return true
|
failures = append(failures, *failure)
|
||||||
|
} else {
|
||||||
|
return true, failures
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false, failures
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) bool {
|
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
|
||||||
|
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
|
||||||
var rm *dns.Msg
|
var rm *dns.Msg
|
||||||
var t time.Duration
|
var t time.Duration
|
||||||
var err error
|
var err error
|
||||||
@@ -171,31 +185,32 @@ func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, u
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger)
|
return u.handleUpstreamError(err, upstream, startTime)
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if rm == nil || !rm.Response {
|
if rm == nil || !rm.Response {
|
||||||
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
|
return &upstreamFailure{upstream: upstream, reason: "no response"}
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
|
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
|
||||||
|
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
|
||||||
|
}
|
||||||
|
|
||||||
|
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) {
|
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
|
||||||
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
|
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
|
||||||
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err)
|
return &upstreamFailure{upstream: upstream, reason: err.Error()}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
elapsed := time.Since(startTime)
|
elapsed := time.Since(startTime)
|
||||||
timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout)
|
reason := fmt.Sprintf("timeout after %v", elapsed.Truncate(time.Millisecond))
|
||||||
if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" {
|
if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" {
|
||||||
timeoutMsg += " " + peerInfo
|
reason += " " + peerInfo
|
||||||
}
|
}
|
||||||
timeoutMsg += fmt.Sprintf(" - error: %v", err)
|
return &upstreamFailure{upstream: upstream, reason: reason}
|
||||||
logger.Warn(timeoutMsg)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
||||||
@@ -215,16 +230,34 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
|
func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) {
|
||||||
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
|
totalUpstreams := len(u.upstreamServers)
|
||||||
|
failedCount := len(failures)
|
||||||
|
failureSummary := formatFailures(failures)
|
||||||
|
|
||||||
|
if succeeded {
|
||||||
|
logger.Warnf("%d/%d upstreams failed for domain=%s: %s", failedCount, totalUpstreams, domain, failureSummary)
|
||||||
|
} else {
|
||||||
|
logger.Errorf("%d/%d upstreams failed for domain=%s: %s", failedCount, totalUpstreams, domain, failureSummary)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetRcode(r, dns.RcodeServerFailure)
|
m.SetRcode(r, dns.RcodeServerFailure)
|
||||||
if err := w.WriteMsg(m); err != nil {
|
if err := w.WriteMsg(m); err != nil {
|
||||||
logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
|
logger.Errorf("write error response for domain=%s: %s", r.Question[0].Name, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func formatFailures(failures []upstreamFailure) string {
|
||||||
|
parts := make([]string, 0, len(failures))
|
||||||
|
for _, f := range failures {
|
||||||
|
parts = append(parts, fmt.Sprintf("%s=%s", f.upstream, f.reason))
|
||||||
|
}
|
||||||
|
return strings.Join(parts, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
// ProbeAvailability tests all upstream servers simultaneously and
|
// ProbeAvailability tests all upstream servers simultaneously and
|
||||||
// disables the resolver if none work
|
// disables the resolver if none work
|
||||||
func (u *upstreamResolverBase) ProbeAvailability() {
|
func (u *upstreamResolverBase) ProbeAvailability() {
|
||||||
@@ -468,7 +501,6 @@ func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upst
|
|||||||
return reply, nil
|
return reply, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
|
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
|
||||||
func FormatPeerStatus(peerState *peer.State) string {
|
func FormatPeerStatus(peerState *peer.State) string {
|
||||||
isConnected := peerState.ConnStatus == peer.StatusConnected
|
isConnected := peerState.ConnStatus == peer.StatusConnected
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -9,6 +10,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
@@ -140,6 +143,23 @@ func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg)
|
|||||||
return c.r, c.rtt, c.err
|
return c.r, c.rtt, c.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockUpstreamResponse struct {
|
||||||
|
msg *dns.Msg
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockUpstreamResolverPerServer struct {
|
||||||
|
responses map[string]mockUpstreamResponse
|
||||||
|
rtt time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
|
||||||
|
if r, ok := c.responses[upstream]; ok {
|
||||||
|
return r.msg, c.rtt, r.err
|
||||||
|
}
|
||||||
|
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
|
||||||
|
}
|
||||||
|
|
||||||
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||||
mockClient := &mockUpstreamResolver{
|
mockClient := &mockUpstreamResolver{
|
||||||
err: dns.ErrTime,
|
err: dns.ErrTime,
|
||||||
@@ -191,3 +211,267 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
|||||||
t.Errorf("should be enabled")
|
t.Errorf("should be enabled")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpstreamResolver_Failover(t *testing.T) {
|
||||||
|
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
|
||||||
|
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
|
||||||
|
|
||||||
|
successAnswer := "192.0.2.100"
|
||||||
|
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
upstream1 mockUpstreamResponse
|
||||||
|
upstream2 mockUpstreamResponse
|
||||||
|
expectedRcode int
|
||||||
|
expectAnswer bool
|
||||||
|
expectTrySecond bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success on first upstream",
|
||||||
|
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||||
|
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||||
|
expectedRcode: dns.RcodeSuccess,
|
||||||
|
expectAnswer: true,
|
||||||
|
expectTrySecond: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SERVFAIL from first should try second",
|
||||||
|
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||||
|
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||||
|
expectedRcode: dns.RcodeSuccess,
|
||||||
|
expectAnswer: true,
|
||||||
|
expectTrySecond: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "REFUSED from first should try second",
|
||||||
|
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
|
||||||
|
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||||
|
expectedRcode: dns.RcodeSuccess,
|
||||||
|
expectAnswer: true,
|
||||||
|
expectTrySecond: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NXDOMAIN from first should NOT try second",
|
||||||
|
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeNameError, "")},
|
||||||
|
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||||
|
expectedRcode: dns.RcodeNameError,
|
||||||
|
expectAnswer: false,
|
||||||
|
expectTrySecond: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "timeout from first should try second",
|
||||||
|
upstream1: mockUpstreamResponse{err: timeoutErr},
|
||||||
|
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||||
|
expectedRcode: dns.RcodeSuccess,
|
||||||
|
expectAnswer: true,
|
||||||
|
expectTrySecond: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no response from first should try second",
|
||||||
|
upstream1: mockUpstreamResponse{msg: nil},
|
||||||
|
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||||
|
expectedRcode: dns.RcodeSuccess,
|
||||||
|
expectAnswer: true,
|
||||||
|
expectTrySecond: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "both upstreams return SERVFAIL",
|
||||||
|
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||||
|
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||||
|
expectedRcode: dns.RcodeServerFailure,
|
||||||
|
expectAnswer: false,
|
||||||
|
expectTrySecond: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "both upstreams timeout",
|
||||||
|
upstream1: mockUpstreamResponse{err: timeoutErr},
|
||||||
|
upstream2: mockUpstreamResponse{err: timeoutErr},
|
||||||
|
expectedRcode: dns.RcodeServerFailure,
|
||||||
|
expectAnswer: false,
|
||||||
|
expectTrySecond: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "first SERVFAIL then timeout",
|
||||||
|
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||||
|
upstream2: mockUpstreamResponse{err: timeoutErr},
|
||||||
|
expectedRcode: dns.RcodeServerFailure,
|
||||||
|
expectAnswer: false,
|
||||||
|
expectTrySecond: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "first timeout then SERVFAIL",
|
||||||
|
upstream1: mockUpstreamResponse{err: timeoutErr},
|
||||||
|
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||||
|
expectedRcode: dns.RcodeServerFailure,
|
||||||
|
expectAnswer: false,
|
||||||
|
expectTrySecond: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "first REFUSED then SERVFAIL",
|
||||||
|
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
|
||||||
|
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||||
|
expectedRcode: dns.RcodeServerFailure,
|
||||||
|
expectAnswer: false,
|
||||||
|
expectTrySecond: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
var queriedUpstreams []string
|
||||||
|
mockClient := &mockUpstreamResolverPerServer{
|
||||||
|
responses: map[string]mockUpstreamResponse{
|
||||||
|
upstream1.String(): tc.upstream1,
|
||||||
|
upstream2.String(): tc.upstream2,
|
||||||
|
},
|
||||||
|
rtt: time.Millisecond,
|
||||||
|
}
|
||||||
|
|
||||||
|
trackingClient := &trackingMockClient{
|
||||||
|
inner: mockClient,
|
||||||
|
queriedUpstreams: &queriedUpstreams,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resolver := &upstreamResolverBase{
|
||||||
|
ctx: ctx,
|
||||||
|
upstreamClient: trackingClient,
|
||||||
|
upstreamServers: []netip.AddrPort{upstream1, upstream2},
|
||||||
|
upstreamTimeout: UpstreamTimeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
var responseMSG *dns.Msg
|
||||||
|
responseWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
responseMSG = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||||
|
resolver.ServeDNS(responseWriter, inputMSG)
|
||||||
|
|
||||||
|
require.NotNil(t, responseMSG, "should write a response")
|
||||||
|
assert.Equal(t, tc.expectedRcode, responseMSG.Rcode, "unexpected rcode")
|
||||||
|
|
||||||
|
if tc.expectAnswer {
|
||||||
|
require.NotEmpty(t, responseMSG.Answer, "expected answer records")
|
||||||
|
assert.Contains(t, responseMSG.Answer[0].String(), successAnswer)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.expectTrySecond {
|
||||||
|
assert.Len(t, queriedUpstreams, 2, "should have tried both upstreams")
|
||||||
|
assert.Equal(t, upstream1.String(), queriedUpstreams[0])
|
||||||
|
assert.Equal(t, upstream2.String(), queriedUpstreams[1])
|
||||||
|
} else {
|
||||||
|
assert.Len(t, queriedUpstreams, 1, "should have only tried first upstream")
|
||||||
|
assert.Equal(t, upstream1.String(), queriedUpstreams[0])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type trackingMockClient struct {
|
||||||
|
inner *mockUpstreamResolverPerServer
|
||||||
|
queriedUpstreams *[]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *trackingMockClient) exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) {
|
||||||
|
*t.queriedUpstreams = append(*t.queriedUpstreams, upstream)
|
||||||
|
return t.inner.exchange(ctx, upstream, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildMockResponse(rcode int, answer string) *dns.Msg {
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.Response = true
|
||||||
|
m.Rcode = rcode
|
||||||
|
|
||||||
|
if rcode == dns.RcodeSuccess && answer != "" {
|
||||||
|
m.Answer = []dns.RR{
|
||||||
|
&dns.A{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: "example.com.",
|
||||||
|
Rrtype: dns.TypeA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: 300,
|
||||||
|
},
|
||||||
|
A: net.ParseIP(answer),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
|
||||||
|
upstream := netip.MustParseAddrPort("192.0.2.1:53")
|
||||||
|
|
||||||
|
mockClient := &mockUpstreamResolverPerServer{
|
||||||
|
responses: map[string]mockUpstreamResponse{
|
||||||
|
upstream.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||||
|
},
|
||||||
|
rtt: time.Millisecond,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resolver := &upstreamResolverBase{
|
||||||
|
ctx: ctx,
|
||||||
|
upstreamClient: mockClient,
|
||||||
|
upstreamServers: []netip.AddrPort{upstream},
|
||||||
|
upstreamTimeout: UpstreamTimeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
var responseMSG *dns.Msg
|
||||||
|
responseWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
responseMSG = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||||
|
resolver.ServeDNS(responseWriter, inputMSG)
|
||||||
|
|
||||||
|
require.NotNil(t, responseMSG, "should write a response")
|
||||||
|
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatFailures(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
failures []upstreamFailure
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty slice",
|
||||||
|
failures: []upstreamFailure{},
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single failure",
|
||||||
|
failures: []upstreamFailure{
|
||||||
|
{upstream: netip.MustParseAddrPort("8.8.8.8:53"), reason: "SERVFAIL"},
|
||||||
|
},
|
||||||
|
expected: "8.8.8.8:53=SERVFAIL",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple failures",
|
||||||
|
failures: []upstreamFailure{
|
||||||
|
{upstream: netip.MustParseAddrPort("8.8.8.8:53"), reason: "SERVFAIL"},
|
||||||
|
{upstream: netip.MustParseAddrPort("8.8.4.4:53"), reason: "timeout after 2s"},
|
||||||
|
},
|
||||||
|
expected: "8.8.8.8:53=SERVFAIL, 8.8.4.4:53=timeout after 2s",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := formatFailures(tc.failures)
|
||||||
|
assert.Equal(t, tc.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/debug"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||||
@@ -42,12 +43,15 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||||
"github.com/netbirdio/netbird/client/internal/relay"
|
"github.com/netbirdio/netbird/client/internal/relay"
|
||||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/jobexec"
|
||||||
cProto "github.com/netbirdio/netbird/client/proto"
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||||
@@ -132,6 +136,16 @@ type EngineConfig struct {
|
|||||||
LazyConnectionEnabled bool
|
LazyConnectionEnabled bool
|
||||||
|
|
||||||
MTU uint16
|
MTU uint16
|
||||||
|
|
||||||
|
// for debug bundle generation
|
||||||
|
ProfileConfig *profilemanager.Config
|
||||||
|
|
||||||
|
LogPath string
|
||||||
|
|
||||||
|
// ProxyConfig contains system proxy settings for macOS
|
||||||
|
ProxyEnabled bool
|
||||||
|
ProxyHost string
|
||||||
|
ProxyPort int
|
||||||
}
|
}
|
||||||
|
|
||||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||||
@@ -195,7 +209,8 @@ type Engine struct {
|
|||||||
stateManager *statemanager.Manager
|
stateManager *statemanager.Manager
|
||||||
srWatcher *guard.SRWatcher
|
srWatcher *guard.SRWatcher
|
||||||
|
|
||||||
// Sync response persistence
|
// Sync response persistence (protected by syncRespMux)
|
||||||
|
syncRespMux sync.RWMutex
|
||||||
persistSyncResponse bool
|
persistSyncResponse bool
|
||||||
latestSyncResponse *mgmProto.SyncResponse
|
latestSyncResponse *mgmProto.SyncResponse
|
||||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||||
@@ -211,6 +226,12 @@ type Engine struct {
|
|||||||
shutdownWg sync.WaitGroup
|
shutdownWg sync.WaitGroup
|
||||||
|
|
||||||
probeStunTurn *relay.StunTurnProbe
|
probeStunTurn *relay.StunTurnProbe
|
||||||
|
|
||||||
|
jobExecutor *jobexec.Executor
|
||||||
|
jobExecutorWG sync.WaitGroup
|
||||||
|
|
||||||
|
// proxyManager manages system-wide browser proxy settings on macOS
|
||||||
|
proxyManager *proxy.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peer is an instance of the Connection Peer
|
// Peer is an instance of the Connection Peer
|
||||||
@@ -224,7 +245,18 @@ type localIpUpdater interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewEngine creates a new Connection Engine with probes attached
|
// NewEngine creates a new Connection Engine with probes attached
|
||||||
func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, stateManager *statemanager.Manager) *Engine {
|
func NewEngine(
|
||||||
|
clientCtx context.Context,
|
||||||
|
clientCancel context.CancelFunc,
|
||||||
|
signalClient signal.Client,
|
||||||
|
mgmClient mgm.Client,
|
||||||
|
relayManager *relayClient.Manager,
|
||||||
|
config *EngineConfig,
|
||||||
|
mobileDep MobileDependency,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
checks []*mgmProto.Checks,
|
||||||
|
stateManager *statemanager.Manager,
|
||||||
|
) *Engine {
|
||||||
engine := &Engine{
|
engine := &Engine{
|
||||||
clientCtx: clientCtx,
|
clientCtx: clientCtx,
|
||||||
clientCancel: clientCancel,
|
clientCancel: clientCancel,
|
||||||
@@ -244,6 +276,7 @@ func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signa
|
|||||||
checks: checks,
|
checks: checks,
|
||||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||||
|
jobExecutor: jobexec.NewExecutor(),
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||||
@@ -289,6 +322,12 @@ func (e *Engine) Stop() error {
|
|||||||
e.updateManager.Stop()
|
e.updateManager.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if e.proxyManager != nil {
|
||||||
|
if err := e.proxyManager.DisableWebProxy(); err != nil {
|
||||||
|
log.Warnf("failed to disable system proxy: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
log.Info("cleaning up status recorder states")
|
log.Info("cleaning up status recorder states")
|
||||||
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
|
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
|
||||||
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
||||||
@@ -312,6 +351,8 @@ func (e *Engine) Stop() error {
|
|||||||
e.cancel()
|
e.cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
e.jobExecutorWG.Wait() // block until job goroutines finish
|
||||||
|
|
||||||
e.close()
|
e.close()
|
||||||
|
|
||||||
// stop flow manager after wg interface is gone
|
// stop flow manager after wg interface is gone
|
||||||
@@ -422,6 +463,10 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
}
|
}
|
||||||
e.stateManager.Start()
|
e.stateManager.Start()
|
||||||
|
|
||||||
|
// Initialize proxy manager and register state for cleanup
|
||||||
|
proxy.RegisterState(e.stateManager)
|
||||||
|
e.proxyManager = proxy.NewManager(e.stateManager)
|
||||||
|
|
||||||
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
|
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.close()
|
e.close()
|
||||||
@@ -479,6 +524,15 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
return fmt.Errorf("up wg interface: %w", err)
|
return fmt.Errorf("up wg interface: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set up notrack rules immediately after proxy is listening to prevent
|
||||||
|
// conntrack entries from being created before the rules are in place
|
||||||
|
e.setupWGProxyNoTrack()
|
||||||
|
|
||||||
|
// Set the WireGuard interface for rosenpass after interface is up
|
||||||
|
if e.rpManager != nil {
|
||||||
|
e.rpManager.SetInterface(e.wgInterface)
|
||||||
|
}
|
||||||
|
|
||||||
// if inbound conns are blocked there is no need to create the ACL manager
|
// if inbound conns are blocked there is no need to create the ACL manager
|
||||||
if e.firewall != nil && !e.config.BlockInbound {
|
if e.firewall != nil && !e.config.BlockInbound {
|
||||||
e.acl = acl.NewDefaultManager(e.firewall)
|
e.acl = acl.NewDefaultManager(e.firewall)
|
||||||
@@ -500,6 +554,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
|
|
||||||
e.receiveSignalEvents()
|
e.receiveSignalEvents()
|
||||||
e.receiveManagementEvents()
|
e.receiveManagementEvents()
|
||||||
|
e.receiveJobEvents()
|
||||||
|
|
||||||
// starting network monitor at the very last to avoid disruptions
|
// starting network monitor at the very last to avoid disruptions
|
||||||
e.startNetworkMonitor()
|
e.startNetworkMonitor()
|
||||||
@@ -537,9 +592,11 @@ func (e *Engine) createFirewall() error {
|
|||||||
|
|
||||||
var err error
|
var err error
|
||||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
||||||
if err != nil || e.firewall == nil {
|
if err != nil {
|
||||||
log.Errorf("failed creating firewall manager: %s", err)
|
return fmt.Errorf("create firewall manager: %w", err)
|
||||||
return nil
|
}
|
||||||
|
if e.firewall == nil {
|
||||||
|
return fmt.Errorf("create firewall manager: received nil manager")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := e.initFirewall(); err != nil {
|
if err := e.initFirewall(); err != nil {
|
||||||
@@ -585,6 +642,23 @@ func (e *Engine) initFirewall() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setupWGProxyNoTrack configures connection tracking exclusion for WireGuard proxy traffic.
|
||||||
|
// This prevents conntrack/MASQUERADE from affecting loopback traffic between WireGuard and the eBPF proxy.
|
||||||
|
func (e *Engine) setupWGProxyNoTrack() {
|
||||||
|
if e.firewall == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyPort := e.wgInterface.GetProxyPort()
|
||||||
|
if proxyPort == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.firewall.SetupEBPFProxyNoTrack(proxyPort, uint16(e.config.WgPort)); err != nil {
|
||||||
|
log.Warnf("failed to setup ebpf proxy notrack: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (e *Engine) blockLanAccess() {
|
func (e *Engine) blockLanAccess() {
|
||||||
if e.config.BlockInbound {
|
if e.config.BlockInbound {
|
||||||
// no need to set up extra deny rules if inbound is already blocked in general
|
// no need to set up extra deny rules if inbound is already blocked in general
|
||||||
@@ -828,9 +902,18 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
|
||||||
|
// Read the storage-enabled flag under the syncRespMux too.
|
||||||
|
e.syncRespMux.RLock()
|
||||||
|
enabled := e.persistSyncResponse
|
||||||
|
e.syncRespMux.RUnlock()
|
||||||
|
|
||||||
// Store sync response if persistence is enabled
|
// Store sync response if persistence is enabled
|
||||||
if e.persistSyncResponse {
|
if enabled {
|
||||||
|
e.syncRespMux.Lock()
|
||||||
e.latestSyncResponse = update
|
e.latestSyncResponse = update
|
||||||
|
e.syncRespMux.Unlock()
|
||||||
|
|
||||||
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
|
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -960,6 +1043,80 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (e *Engine) receiveJobEvents() {
|
||||||
|
e.jobExecutorWG.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer e.jobExecutorWG.Done()
|
||||||
|
err := e.mgmClient.Job(e.ctx, func(msg *mgmProto.JobRequest) *mgmProto.JobResponse {
|
||||||
|
resp := mgmProto.JobResponse{
|
||||||
|
ID: msg.ID,
|
||||||
|
Status: mgmProto.JobStatus_failed,
|
||||||
|
}
|
||||||
|
switch params := msg.WorkloadParameters.(type) {
|
||||||
|
case *mgmProto.JobRequest_Bundle:
|
||||||
|
bundleResult, err := e.handleBundle(params.Bundle)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("handling bundle: %v", err)
|
||||||
|
resp.Reason = []byte(err.Error())
|
||||||
|
return &resp
|
||||||
|
}
|
||||||
|
resp.Status = mgmProto.JobStatus_succeeded
|
||||||
|
resp.WorkloadResults = bundleResult
|
||||||
|
return &resp
|
||||||
|
default:
|
||||||
|
resp.Reason = []byte(jobexec.ErrJobNotImplemented.Error())
|
||||||
|
return &resp
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
// happens if management is unavailable for a long time.
|
||||||
|
// We want to cancel the operation of the whole client
|
||||||
|
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||||
|
e.clientCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Info("stopped receiving jobs from Management Service")
|
||||||
|
}()
|
||||||
|
log.Info("connecting to Management Service jobs stream")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobResponse_Bundle, error) {
|
||||||
|
log.Infof("handle remote debug bundle request: %s", params.String())
|
||||||
|
syncResponse, err := e.GetLatestSyncResponse()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("get latest sync response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bundleDeps := debug.GeneratorDependencies{
|
||||||
|
InternalConfig: e.config.ProfileConfig,
|
||||||
|
StatusRecorder: e.statusRecorder,
|
||||||
|
SyncResponse: syncResponse,
|
||||||
|
LogPath: e.config.LogPath,
|
||||||
|
RefreshStatus: func() {
|
||||||
|
e.RunHealthProbes(true)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
bundleJobParams := debug.BundleConfig{
|
||||||
|
Anonymize: params.Anonymize,
|
||||||
|
IncludeSystemInfo: true,
|
||||||
|
LogFileCount: uint32(params.LogFileCount),
|
||||||
|
}
|
||||||
|
|
||||||
|
waitFor := time.Duration(params.BundleForTime) * time.Minute
|
||||||
|
|
||||||
|
uploadKey, err := e.jobExecutor.BundleJob(e.ctx, bundleDeps, bundleJobParams, waitFor, e.config.ProfileConfig.ManagementURL.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
response := &mgmProto.JobResponse_Bundle{
|
||||||
|
Bundle: &mgmProto.BundleResult{
|
||||||
|
UploadKey: uploadKey,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
|
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
|
||||||
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
||||||
@@ -1174,6 +1331,9 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
// If no server of a server group responds this will disable the respective handler and retry later.
|
// If no server of a server group responds this will disable the respective handler and retry later.
|
||||||
e.dnsServer.ProbeAvailability()
|
e.dnsServer.ProbeAvailability()
|
||||||
|
|
||||||
|
// Update system proxy state based on routes after network map is fully applied
|
||||||
|
e.updateSystemProxy(clientRoutes)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1405,6 +1565,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
|||||||
if e.rpManager != nil {
|
if e.rpManager != nil {
|
||||||
peerConn.SetOnConnected(e.rpManager.OnConnected)
|
peerConn.SetOnConnected(e.rpManager.OnConnected)
|
||||||
peerConn.SetOnDisconnected(e.rpManager.OnDisconnected)
|
peerConn.SetOnDisconnected(e.rpManager.OnDisconnected)
|
||||||
|
peerConn.SetRosenpassInitializedPresharedKeyValidator(e.rpManager.IsPresharedKeyInitialized)
|
||||||
}
|
}
|
||||||
|
|
||||||
return peerConn, nil
|
return peerConn, nil
|
||||||
@@ -1528,6 +1689,7 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
|||||||
|
|
||||||
func (e *Engine) close() {
|
func (e *Engine) close() {
|
||||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||||
|
|
||||||
if e.wgInterface != nil {
|
if e.wgInterface != nil {
|
||||||
if err := e.wgInterface.Close(); err != nil {
|
if err := e.wgInterface.Close(); err != nil {
|
||||||
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
|
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
|
||||||
@@ -1714,7 +1876,7 @@ func (e *Engine) getRosenpassAddr() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
|
// RunHealthProbes executes health checks for Signal, Management, Relay, and WireGuard services
|
||||||
// and updates the status recorder with the latest states.
|
// and updates the status recorder with the latest states.
|
||||||
func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
@@ -1728,23 +1890,8 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
|||||||
stuns := slices.Clone(e.STUNs)
|
stuns := slices.Clone(e.STUNs)
|
||||||
turns := slices.Clone(e.TURNs)
|
turns := slices.Clone(e.TURNs)
|
||||||
|
|
||||||
if e.wgInterface != nil {
|
if err := e.statusRecorder.RefreshWireGuardStats(); err != nil {
|
||||||
stats, err := e.wgInterface.GetStats()
|
log.Debugf("failed to refresh WireGuard stats: %v", err)
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to get wireguard stats: %v", err)
|
|
||||||
e.syncMsgMux.Unlock()
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for _, key := range e.peerStore.PeersPubKey() {
|
|
||||||
// wgStats could be zero value, in which case we just reset the stats
|
|
||||||
wgStats, ok := stats[key]
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
|
|
||||||
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
e.syncMsgMux.Unlock()
|
e.syncMsgMux.Unlock()
|
||||||
@@ -1848,8 +1995,8 @@ func (e *Engine) stopDNSServer() {
|
|||||||
|
|
||||||
// SetSyncResponsePersistence enables or disables sync response persistence
|
// SetSyncResponsePersistence enables or disables sync response persistence
|
||||||
func (e *Engine) SetSyncResponsePersistence(enabled bool) {
|
func (e *Engine) SetSyncResponsePersistence(enabled bool) {
|
||||||
e.syncMsgMux.Lock()
|
e.syncRespMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncRespMux.Unlock()
|
||||||
|
|
||||||
if enabled == e.persistSyncResponse {
|
if enabled == e.persistSyncResponse {
|
||||||
return
|
return
|
||||||
@@ -1864,20 +2011,22 @@ func (e *Engine) SetSyncResponsePersistence(enabled bool) {
|
|||||||
|
|
||||||
// GetLatestSyncResponse returns the stored sync response if persistence is enabled
|
// GetLatestSyncResponse returns the stored sync response if persistence is enabled
|
||||||
func (e *Engine) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
|
func (e *Engine) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
|
||||||
e.syncMsgMux.Lock()
|
e.syncRespMux.RLock()
|
||||||
defer e.syncMsgMux.Unlock()
|
enabled := e.persistSyncResponse
|
||||||
|
latest := e.latestSyncResponse
|
||||||
|
e.syncRespMux.RUnlock()
|
||||||
|
|
||||||
if !e.persistSyncResponse {
|
if !enabled {
|
||||||
return nil, errors.New("sync response persistence is disabled")
|
return nil, errors.New("sync response persistence is disabled")
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.latestSyncResponse == nil {
|
if latest == nil {
|
||||||
//nolint:nilnil
|
//nolint:nilnil
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(e.latestSyncResponse))
|
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(latest))
|
||||||
sr, ok := proto.Clone(e.latestSyncResponse).(*mgmProto.SyncResponse)
|
sr, ok := proto.Clone(latest).(*mgmProto.SyncResponse)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("failed to clone sync response")
|
return nil, fmt.Errorf("failed to clone sync response")
|
||||||
}
|
}
|
||||||
@@ -2176,6 +2325,26 @@ func createFile(path string) error {
|
|||||||
return file.Close()
|
return file.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// updateSystemProxy triggers a proxy enable/disable cycle after the network map is updated.
|
||||||
|
func (e *Engine) updateSystemProxy(clientRoutes route.HAMap) {
|
||||||
|
if runtime.GOOS != "darwin" || e.proxyManager == nil {
|
||||||
|
log.Errorf("not updating proxy")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.proxyManager.EnableWebProxy(e.config.ProxyHost, e.config.ProxyPort); err != nil {
|
||||||
|
log.Errorf("enable system proxy: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Error("system proxy enabled after network map update")
|
||||||
|
|
||||||
|
if err := e.proxyManager.DisableWebProxy(); err != nil {
|
||||||
|
log.Errorf("disable system proxy: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Error("system proxy disabled after network map update")
|
||||||
|
}
|
||||||
|
|
||||||
func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
|
func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
|
||||||
remoteCred, err := signal.UnMarshalCredential(msg)
|
remoteCred, err := signal.UnMarshalCredential(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import (
|
|||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
"github.com/netbirdio/netbird/management/server/job"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
@@ -106,6 +107,7 @@ type MockWGIface struct {
|
|||||||
GetStatsFunc func() (map[string]configurer.WGStats, error)
|
GetStatsFunc func() (map[string]configurer.WGStats, error)
|
||||||
GetInterfaceGUIDStringFunc func() (string, error)
|
GetInterfaceGUIDStringFunc func() (string, error)
|
||||||
GetProxyFunc func() wgproxy.Proxy
|
GetProxyFunc func() wgproxy.Proxy
|
||||||
|
GetProxyPortFunc func() uint16
|
||||||
GetNetFunc func() *netstack.Net
|
GetNetFunc func() *netstack.Net
|
||||||
LastActivitiesFunc func() map[string]monotime.Time
|
LastActivitiesFunc func() map[string]monotime.Time
|
||||||
}
|
}
|
||||||
@@ -202,6 +204,13 @@ func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
|||||||
return m.GetProxyFunc()
|
return m.GetProxyFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockWGIface) GetProxyPort() uint16 {
|
||||||
|
if m.GetProxyPortFunc != nil {
|
||||||
|
return m.GetProxyPortFunc()
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
func (m *MockWGIface) GetNet() *netstack.Net {
|
func (m *MockWGIface) GetNet() *netstack.Net {
|
||||||
return m.GetNetFunc()
|
return m.GetNetFunc()
|
||||||
}
|
}
|
||||||
@@ -213,6 +222,10 @@ func (m *MockWGIface) LastActivities() map[string]monotime.Time {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockWGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
_ = util.InitLog("debug", util.LogConsole)
|
_ = util.InitLog("debug", util.LogConsole)
|
||||||
code := m.Run()
|
code := m.Run()
|
||||||
@@ -1599,6 +1612,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
|
|
||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
peersManager := peers.NewManager(store, permissionsManager)
|
peersManager := peers.NewManager(store, permissionsManager)
|
||||||
|
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||||
|
|
||||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore)
|
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore)
|
||||||
|
|
||||||
@@ -1622,7 +1636,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
@@ -1631,7 +1645,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ type wgIfaceBase interface {
|
|||||||
Up() (*udpmux.UniversalUDPMuxDefault, error)
|
Up() (*udpmux.UniversalUDPMuxDefault, error)
|
||||||
UpdateAddr(newAddr string) error
|
UpdateAddr(newAddr string) error
|
||||||
GetProxy() wgproxy.Proxy
|
GetProxy() wgproxy.Proxy
|
||||||
|
GetProxyPort() uint16
|
||||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||||
RemoveEndpointAddress(key string) error
|
RemoveEndpointAddress(key string) error
|
||||||
RemovePeer(peerKey string) error
|
RemovePeer(peerKey string) error
|
||||||
@@ -42,4 +43,5 @@ type wgIfaceBase interface {
|
|||||||
GetNet() *netstack.Net
|
GetNet() *netstack.Net
|
||||||
FullStats() (*configurer.Stats, error)
|
FullStats() (*configurer.Stats, error)
|
||||||
LastActivities() map[string]monotime.Time
|
LastActivities() map[string]monotime.Time
|
||||||
|
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,201 +0,0 @@
|
|||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// IsLoginRequired check that the server is support SSO or not
|
|
||||||
func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, error) {
|
|
||||||
mgmURL := config.ManagementURL
|
|
||||||
mgmClient, err := getMgmClient(ctx, config.PrivateKey, mgmURL)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err = mgmClient.Close()
|
|
||||||
if err != nil {
|
|
||||||
cStatus, ok := status.FromError(err)
|
|
||||||
if !ok || ok && cStatus.Code() != codes.Canceled {
|
|
||||||
log.Warnf("failed to close the Management service client, err: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
|
||||||
|
|
||||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
|
||||||
if isLoginNeeded(err) {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Login or register the client
|
|
||||||
func Login(ctx context.Context, config *profilemanager.Config, setupKey string, jwtToken string) error {
|
|
||||||
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err = mgmClient.Close()
|
|
||||||
if err != nil {
|
|
||||||
cStatus, ok := status.FromError(err)
|
|
||||||
if !ok || ok && cStatus.Code() != codes.Canceled {
|
|
||||||
log.Warnf("failed to close the Management service client, err: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
|
|
||||||
|
|
||||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
|
||||||
if serverKey != nil && isRegistrationNeeded(err) {
|
|
||||||
log.Debugf("peer registration required")
|
|
||||||
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
|
|
||||||
// validate our peer's Wireguard PRIVATE key
|
|
||||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var mgmTlsEnabled bool
|
|
||||||
if mgmURL.Scheme == "https" {
|
|
||||||
mgmTlsEnabled = true
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("connecting to the Management service %s", mgmURL.String())
|
|
||||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return mgmClient, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
|
||||||
serverKey, err := mgmClient.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sysInfo := system.GetInfo(ctx)
|
|
||||||
sysInfo.SetFlags(
|
|
||||||
config.RosenpassEnabled,
|
|
||||||
config.RosenpassPermissive,
|
|
||||||
config.ServerSSHAllowed,
|
|
||||||
config.DisableClientRoutes,
|
|
||||||
config.DisableServerRoutes,
|
|
||||||
config.DisableDNS,
|
|
||||||
config.DisableFirewall,
|
|
||||||
config.BlockLANAccess,
|
|
||||||
config.BlockInbound,
|
|
||||||
config.LazyConnectionEnabled,
|
|
||||||
config.EnableSSHRoot,
|
|
||||||
config.EnableSSHSFTP,
|
|
||||||
config.EnableSSHLocalPortForwarding,
|
|
||||||
config.EnableSSHRemotePortForwarding,
|
|
||||||
config.DisableSSHAuth,
|
|
||||||
)
|
|
||||||
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
|
|
||||||
return serverKey, loginResp, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
|
||||||
// Otherwise tries to register with the provided setupKey via command line.
|
|
||||||
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
|
||||||
validSetupKey, err := uuid.Parse(setupKey)
|
|
||||||
if err != nil && jwtToken == "" {
|
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("sending peer registration request to Management Service")
|
|
||||||
info := system.GetInfo(ctx)
|
|
||||||
info.SetFlags(
|
|
||||||
config.RosenpassEnabled,
|
|
||||||
config.RosenpassPermissive,
|
|
||||||
config.ServerSSHAllowed,
|
|
||||||
config.DisableClientRoutes,
|
|
||||||
config.DisableServerRoutes,
|
|
||||||
config.DisableDNS,
|
|
||||||
config.DisableFirewall,
|
|
||||||
config.BlockLANAccess,
|
|
||||||
config.BlockInbound,
|
|
||||||
config.LazyConnectionEnabled,
|
|
||||||
config.EnableSSHRoot,
|
|
||||||
config.EnableSSHSFTP,
|
|
||||||
config.EnableSSHLocalPortForwarding,
|
|
||||||
config.EnableSSHRemotePortForwarding,
|
|
||||||
config.DisableSSHAuth,
|
|
||||||
)
|
|
||||||
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed registering peer %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("peer has been successfully registered on Management Service")
|
|
||||||
|
|
||||||
return loginResp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isLoginNeeded(err error) bool {
|
|
||||||
if err == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
s, ok := status.FromError(err)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func isRegistrationNeeded(err error) bool {
|
|
||||||
if err == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
s, ok := status.FromError(err)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if s.Code() == codes.PermissionDenied {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -37,6 +38,11 @@ func New() *NetworkMonitor {
|
|||||||
|
|
||||||
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
|
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
|
||||||
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
|
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
log.Debugf("Network monitor: skipping in netstack mode")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
nw.mu.Lock()
|
nw.mu.Lock()
|
||||||
if nw.cancel != nil {
|
if nw.cancel != nil {
|
||||||
nw.mu.Unlock()
|
nw.mu.Unlock()
|
||||||
|
|||||||
@@ -88,8 +88,9 @@ type Conn struct {
|
|||||||
relayManager *relayClient.Manager
|
relayManager *relayClient.Manager
|
||||||
srWatcher *guard.SRWatcher
|
srWatcher *guard.SRWatcher
|
||||||
|
|
||||||
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
||||||
onDisconnected func(remotePeer string)
|
onDisconnected func(remotePeer string)
|
||||||
|
rosenpassInitializedPresharedKeyValidator func(peerKey string) bool
|
||||||
|
|
||||||
statusRelay *worker.AtomicWorkerStatus
|
statusRelay *worker.AtomicWorkerStatus
|
||||||
statusICE *worker.AtomicWorkerStatus
|
statusICE *worker.AtomicWorkerStatus
|
||||||
@@ -98,7 +99,10 @@ type Conn struct {
|
|||||||
|
|
||||||
workerICE *WorkerICE
|
workerICE *WorkerICE
|
||||||
workerRelay *WorkerRelay
|
workerRelay *WorkerRelay
|
||||||
wgWatcherWg sync.WaitGroup
|
|
||||||
|
wgWatcher *WGWatcher
|
||||||
|
wgWatcherWg sync.WaitGroup
|
||||||
|
wgWatcherCancel context.CancelFunc
|
||||||
|
|
||||||
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
|
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
|
||||||
rosenpassRemoteKey []byte
|
rosenpassRemoteKey []byte
|
||||||
@@ -126,6 +130,7 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
|||||||
|
|
||||||
connLog := log.WithField("peer", config.Key)
|
connLog := log.WithField("peer", config.Key)
|
||||||
|
|
||||||
|
dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
|
||||||
var conn = &Conn{
|
var conn = &Conn{
|
||||||
Log: connLog,
|
Log: connLog,
|
||||||
config: config,
|
config: config,
|
||||||
@@ -137,8 +142,9 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
|||||||
semaphore: services.Semaphore,
|
semaphore: services.Semaphore,
|
||||||
statusRelay: worker.NewAtomicStatus(),
|
statusRelay: worker.NewAtomicStatus(),
|
||||||
statusICE: worker.NewAtomicStatus(),
|
statusICE: worker.NewAtomicStatus(),
|
||||||
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
|
dumpState: dumpState,
|
||||||
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
|
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
|
||||||
|
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
@@ -162,7 +168,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
|||||||
|
|
||||||
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
|
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
|
||||||
|
|
||||||
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
|
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
|
||||||
|
|
||||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||||
@@ -231,7 +237,9 @@ func (conn *Conn) Close(signalToRemote bool) {
|
|||||||
conn.Log.Infof("close peer connection")
|
conn.Log.Infof("close peer connection")
|
||||||
conn.ctxCancel()
|
conn.ctxCancel()
|
||||||
|
|
||||||
conn.workerRelay.DisableWgWatcher()
|
if conn.wgWatcherCancel != nil {
|
||||||
|
conn.wgWatcherCancel()
|
||||||
|
}
|
||||||
conn.workerRelay.CloseConn()
|
conn.workerRelay.CloseConn()
|
||||||
conn.workerICE.Close()
|
conn.workerICE.Close()
|
||||||
|
|
||||||
@@ -289,6 +297,13 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
|
|||||||
conn.onDisconnected = handler
|
conn.onDisconnected = handler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRosenpassInitializedPresharedKeyValidator sets a function to check if Rosenpass has taken over
|
||||||
|
// PSK management for a peer. When this returns true, presharedKey() returns nil
|
||||||
|
// to prevent UpdatePeer from overwriting the Rosenpass-managed PSK.
|
||||||
|
func (conn *Conn) SetRosenpassInitializedPresharedKeyValidator(handler func(peerKey string) bool) {
|
||||||
|
conn.rosenpassInitializedPresharedKeyValidator = handler
|
||||||
|
}
|
||||||
|
|
||||||
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) {
|
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) {
|
||||||
conn.dumpState.RemoteOffer()
|
conn.dumpState.RemoteOffer()
|
||||||
conn.Log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
|
conn.Log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
|
||||||
@@ -366,9 +381,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
|||||||
ep = directEp
|
ep = directEp
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.workerRelay.DisableWgWatcher()
|
|
||||||
// todo consider to run conn.wgWatcherWg.Wait() here
|
|
||||||
|
|
||||||
if conn.wgProxyRelay != nil {
|
if conn.wgProxyRelay != nil {
|
||||||
conn.wgProxyRelay.Pause()
|
conn.wgProxyRelay.Pause()
|
||||||
}
|
}
|
||||||
@@ -378,6 +390,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
|||||||
}
|
}
|
||||||
|
|
||||||
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
|
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
|
||||||
|
conn.enableWgWatcherIfNeeded()
|
||||||
|
|
||||||
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
|
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
|
||||||
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
|
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
|
||||||
conn.handleConfigurationFailure(err, wgProxy)
|
conn.handleConfigurationFailure(err, wgProxy)
|
||||||
@@ -423,11 +437,6 @@ func (conn *Conn) onICEStateDisconnected() {
|
|||||||
conn.Log.Errorf("failed to switch to relay conn: %v", err)
|
conn.Log.Errorf("failed to switch to relay conn: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.wgWatcherWg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer conn.wgWatcherWg.Done()
|
|
||||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
|
||||||
}()
|
|
||||||
conn.wgProxyRelay.Work()
|
conn.wgProxyRelay.Work()
|
||||||
conn.currentConnPriority = conntype.Relay
|
conn.currentConnPriority = conntype.Relay
|
||||||
} else {
|
} else {
|
||||||
@@ -444,15 +453,15 @@ func (conn *Conn) onICEStateDisconnected() {
|
|||||||
}
|
}
|
||||||
conn.statusICE.SetDisconnected()
|
conn.statusICE.SetDisconnected()
|
||||||
|
|
||||||
|
conn.disableWgWatcherIfNeeded()
|
||||||
|
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
ConnStatus: conn.evalStatus(),
|
ConnStatus: conn.evalStatus(),
|
||||||
Relayed: conn.isRelayed(),
|
Relayed: conn.isRelayed(),
|
||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
}
|
}
|
||||||
|
if err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState); err != nil {
|
||||||
err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState)
|
|
||||||
if err != nil {
|
|
||||||
conn.Log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
|
conn.Log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -492,6 +501,9 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
|
|
||||||
wgProxy.Work()
|
wgProxy.Work()
|
||||||
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
|
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
|
||||||
|
|
||||||
|
conn.enableWgWatcherIfNeeded()
|
||||||
|
|
||||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
|
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
|
||||||
if err := wgProxy.CloseConn(); err != nil {
|
if err := wgProxy.CloseConn(); err != nil {
|
||||||
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
||||||
@@ -500,12 +512,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.wgWatcherWg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer conn.wgWatcherWg.Done()
|
|
||||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
|
||||||
}()
|
|
||||||
|
|
||||||
wgConfigWorkaround()
|
wgConfigWorkaround()
|
||||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||||
conn.currentConnPriority = conntype.Relay
|
conn.currentConnPriority = conntype.Relay
|
||||||
@@ -519,7 +525,11 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
func (conn *Conn) onRelayDisconnected() {
|
func (conn *Conn) onRelayDisconnected() {
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
conn.handleRelayDisconnectedLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleRelayDisconnectedLocked handles relay disconnection. Caller must hold conn.mu.
|
||||||
|
func (conn *Conn) handleRelayDisconnectedLocked() {
|
||||||
if conn.ctx.Err() != nil {
|
if conn.ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -545,6 +555,8 @@ func (conn *Conn) onRelayDisconnected() {
|
|||||||
}
|
}
|
||||||
conn.statusRelay.SetDisconnected()
|
conn.statusRelay.SetDisconnected()
|
||||||
|
|
||||||
|
conn.disableWgWatcherIfNeeded()
|
||||||
|
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
ConnStatus: conn.evalStatus(),
|
ConnStatus: conn.evalStatus(),
|
||||||
@@ -563,6 +575,28 @@ func (conn *Conn) onGuardEvent() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) onWGDisconnected() {
|
||||||
|
conn.mu.Lock()
|
||||||
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
|
if conn.ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.Log.Warnf("WireGuard handshake timeout detected, closing current connection")
|
||||||
|
|
||||||
|
// Close the active connection based on current priority
|
||||||
|
switch conn.currentConnPriority {
|
||||||
|
case conntype.Relay:
|
||||||
|
conn.workerRelay.CloseConn()
|
||||||
|
conn.handleRelayDisconnectedLocked()
|
||||||
|
case conntype.ICEP2P, conntype.ICETurn:
|
||||||
|
conn.workerICE.Close()
|
||||||
|
default:
|
||||||
|
conn.Log.Debugf("No active connection to close on WG timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
|
func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
@@ -689,6 +723,25 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) enableWgWatcherIfNeeded() {
|
||||||
|
if !conn.wgWatcher.IsEnabled() {
|
||||||
|
wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx)
|
||||||
|
conn.wgWatcherCancel = wgWatcherCancel
|
||||||
|
conn.wgWatcherWg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer conn.wgWatcherWg.Done()
|
||||||
|
conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, conn.onWGDisconnected)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) disableWgWatcherIfNeeded() {
|
||||||
|
if conn.currentConnPriority == conntype.None && conn.wgWatcherCancel != nil {
|
||||||
|
conn.wgWatcherCancel()
|
||||||
|
conn.wgWatcherCancel = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||||
conn.Log.Debugf("setup proxied WireGuard connection")
|
conn.Log.Debugf("setup proxied WireGuard connection")
|
||||||
udpAddr := &net.UDPAddr{
|
udpAddr := &net.UDPAddr{
|
||||||
@@ -759,10 +812,24 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
|||||||
return conn.config.WgConfig.PreSharedKey
|
return conn.config.WgConfig.PreSharedKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If Rosenpass has already set a PSK for this peer, return nil to prevent
|
||||||
|
// UpdatePeer from overwriting the Rosenpass-managed key.
|
||||||
|
if conn.rosenpassInitializedPresharedKeyValidator != nil && conn.rosenpassInitializedPresharedKeyValidator(conn.config.Key) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use NetBird PSK as the seed for Rosenpass. This same PSK is passed to
|
||||||
|
// Rosenpass as PeerConfig.PresharedKey, ensuring the derived post-quantum
|
||||||
|
// key is cryptographically bound to the original secret.
|
||||||
|
if conn.config.WgConfig.PreSharedKey != nil {
|
||||||
|
return conn.config.WgConfig.PreSharedKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to deterministic key if no NetBird PSK is configured
|
||||||
determKey, err := conn.rosenpassDetermKey()
|
determKey, err := conn.rosenpassDetermKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
|
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
|
||||||
return conn.config.WgConfig.PreSharedKey
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return determKey
|
return determKey
|
||||||
|
|||||||
@@ -284,3 +284,27 @@ func TestConn_presharedKey(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConn_presharedKey_RosenpassManaged(t *testing.T) {
|
||||||
|
conn := Conn{
|
||||||
|
config: ConnConfig{
|
||||||
|
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||||
|
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||||
|
RosenpassConfig: RosenpassConfig{PubKey: []byte("dummykey")},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// When Rosenpass has already initialized the PSK for this peer,
|
||||||
|
// presharedKey must return nil to avoid UpdatePeer overwriting it.
|
||||||
|
conn.rosenpassInitializedPresharedKeyValidator = func(peerKey string) bool { return true }
|
||||||
|
if k := conn.presharedKey([]byte("remote")); k != nil {
|
||||||
|
t.Fatalf("expected nil presharedKey when Rosenpass manages PSK, got %v", k)
|
||||||
|
}
|
||||||
|
|
||||||
|
// When Rosenpass hasn't taken over yet, presharedKey should provide
|
||||||
|
// a non-nil initial key (deterministic or from NetBird PSK).
|
||||||
|
conn.rosenpassInitializedPresharedKeyValidator = func(peerKey string) bool { return false }
|
||||||
|
if k := conn.presharedKey([]byte("remote")); k == nil {
|
||||||
|
t.Fatalf("expected non-nil presharedKey before Rosenpass manages PSK")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1145,6 +1145,38 @@ func (d *Status) PeersStatus() (*configurer.Stats, error) {
|
|||||||
return d.wgIface.FullStats()
|
return d.wgIface.FullStats()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RefreshWireGuardStats fetches fresh WireGuard statistics from the interface
|
||||||
|
// and updates the cached peer states. This ensures accurate handshake times and
|
||||||
|
// transfer statistics in status reports without running full health probes.
|
||||||
|
func (d *Status) RefreshWireGuardStats() error {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
if d.wgIface == nil {
|
||||||
|
return nil // silently skip if interface not set
|
||||||
|
}
|
||||||
|
|
||||||
|
stats, err := d.wgIface.FullStats()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get wireguard stats: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update each peer's WireGuard statistics
|
||||||
|
for _, peerStats := range stats.Peers {
|
||||||
|
peerState, ok := d.peers[peerStats.PublicKey]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
peerState.LastWireguardHandshake = peerStats.LastHandshake
|
||||||
|
peerState.BytesRx = peerStats.RxBytes
|
||||||
|
peerState.BytesTx = peerStats.TxBytes
|
||||||
|
d.peers[peerStats.PublicKey] = peerState
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type EventQueue struct {
|
type EventQueue struct {
|
||||||
maxSize int
|
maxSize int
|
||||||
events []*proto.SystemEvent
|
events []*proto.SystemEvent
|
||||||
|
|||||||
@@ -30,10 +30,8 @@ type WGWatcher struct {
|
|||||||
peerKey string
|
peerKey string
|
||||||
stateDump *stateDump
|
stateDump *stateDump
|
||||||
|
|
||||||
ctx context.Context
|
enabled bool
|
||||||
ctxCancel context.CancelFunc
|
muEnabled sync.RWMutex
|
||||||
ctxLock sync.Mutex
|
|
||||||
enabledTime time.Time
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
|
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
|
||||||
@@ -46,52 +44,44 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing.
|
// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing.
|
||||||
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
|
// The watcher runs until ctx is cancelled. Caller is responsible for context lifecycle management.
|
||||||
w.log.Debugf("enable WireGuard watcher")
|
func (w *WGWatcher) EnableWgWatcher(ctx context.Context, onDisconnectedFn func()) {
|
||||||
w.ctxLock.Lock()
|
w.muEnabled.Lock()
|
||||||
w.enabledTime = time.Now()
|
if w.enabled {
|
||||||
|
w.muEnabled.Unlock()
|
||||||
if w.ctx != nil && w.ctx.Err() == nil {
|
|
||||||
w.log.Errorf("WireGuard watcher already enabled")
|
|
||||||
w.ctxLock.Unlock()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, ctxCancel := context.WithCancel(parentCtx)
|
w.log.Debugf("enable WireGuard watcher")
|
||||||
w.ctx = ctx
|
enabledTime := time.Now()
|
||||||
w.ctxCancel = ctxCancel
|
w.enabled = true
|
||||||
w.ctxLock.Unlock()
|
w.muEnabled.Unlock()
|
||||||
|
|
||||||
initialHandshake, err := w.wgState()
|
initialHandshake, err := w.wgState()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.log.Warnf("failed to read initial wg stats: %v", err)
|
w.log.Warnf("failed to read initial wg stats: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
|
w.periodicHandshakeCheck(ctx, onDisconnectedFn, enabledTime, initialHandshake)
|
||||||
|
|
||||||
|
w.muEnabled.Lock()
|
||||||
|
w.enabled = false
|
||||||
|
w.muEnabled.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit
|
// IsEnabled returns true if the WireGuard watcher is currently enabled
|
||||||
func (w *WGWatcher) DisableWgWatcher() {
|
func (w *WGWatcher) IsEnabled() bool {
|
||||||
w.ctxLock.Lock()
|
w.muEnabled.RLock()
|
||||||
defer w.ctxLock.Unlock()
|
defer w.muEnabled.RUnlock()
|
||||||
|
return w.enabled
|
||||||
if w.ctxCancel == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.log.Debugf("disable WireGuard watcher")
|
|
||||||
|
|
||||||
w.ctxCancel()
|
|
||||||
w.ctxCancel = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
||||||
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) {
|
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
|
||||||
w.log.Infof("WireGuard watcher started")
|
w.log.Infof("WireGuard watcher started")
|
||||||
|
|
||||||
timer := time.NewTimer(wgHandshakeOvertime)
|
timer := time.NewTimer(wgHandshakeOvertime)
|
||||||
defer timer.Stop()
|
defer timer.Stop()
|
||||||
defer ctxCancel()
|
|
||||||
|
|
||||||
lastHandshake := initialHandshake
|
lastHandshake := initialHandshake
|
||||||
|
|
||||||
@@ -104,7 +94,7 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if lastHandshake.IsZero() {
|
if lastHandshake.IsZero() {
|
||||||
elapsed := handshake.Sub(w.enabledTime).Seconds()
|
elapsed := calcElapsed(enabledTime, *handshake)
|
||||||
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
|
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -134,19 +124,19 @@ func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
|
|||||||
|
|
||||||
// the current know handshake did not change
|
// the current know handshake did not change
|
||||||
if handshake.Equal(lastHandshake) {
|
if handshake.Equal(lastHandshake) {
|
||||||
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
|
w.log.Warnf("WireGuard handshake timed out: %v", handshake)
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// in case if the machine is suspended, the handshake time will be in the past
|
// in case if the machine is suspended, the handshake time will be in the past
|
||||||
if handshake.Add(checkPeriod).Before(time.Now()) {
|
if handshake.Add(checkPeriod).Before(time.Now()) {
|
||||||
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
|
w.log.Warnf("WireGuard handshake timed out: %v", handshake)
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// error handling for handshake time in the future
|
// error handling for handshake time in the future
|
||||||
if handshake.After(time.Now()) {
|
if handshake.After(time.Now()) {
|
||||||
w.log.Warnf("WireGuard handshake is in the future, closing relay connection: %v", handshake)
|
w.log.Warnf("WireGuard handshake is in the future: %v", handshake)
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -164,3 +154,13 @@ func (w *WGWatcher) wgState() (time.Time, error) {
|
|||||||
}
|
}
|
||||||
return wgState.LastHandshake, nil
|
return wgState.LastHandshake, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// calcElapsed calculates elapsed time since watcher was enabled.
|
||||||
|
// The watcher started after the wg configuration happens, because of this need to normalise the negative value
|
||||||
|
func calcElapsed(enabledTime, handshake time.Time) float64 {
|
||||||
|
elapsed := handshake.Sub(enabledTime).Seconds()
|
||||||
|
if elapsed < 0 {
|
||||||
|
elapsed = 0
|
||||||
|
}
|
||||||
|
return elapsed
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package peer
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -48,7 +49,6 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
|
|||||||
case <-time.After(10 * time.Second):
|
case <-time.After(10 * time.Second):
|
||||||
t.Errorf("timeout")
|
t.Errorf("timeout")
|
||||||
}
|
}
|
||||||
watcher.DisableWgWatcher()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWGWatcher_ReEnable(t *testing.T) {
|
func TestWGWatcher_ReEnable(t *testing.T) {
|
||||||
@@ -60,14 +60,21 @@ func TestWGWatcher_ReEnable(t *testing.T) {
|
|||||||
watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump("peer", mlog, &Status{}))
|
watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump("peer", mlog, &Status{}))
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
watcher.EnableWgWatcher(ctx, func() {})
|
||||||
|
}()
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Re-enable with a new context
|
||||||
|
ctx, cancel = context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
onDisconnected := make(chan struct{}, 1)
|
onDisconnected := make(chan struct{}, 1)
|
||||||
|
|
||||||
go watcher.EnableWgWatcher(ctx, func() {})
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
watcher.DisableWgWatcher()
|
|
||||||
|
|
||||||
go watcher.EnableWgWatcher(ctx, func() {
|
go watcher.EnableWgWatcher(ctx, func() {
|
||||||
onDisconnected <- struct{}{}
|
onDisconnected <- struct{}{}
|
||||||
})
|
})
|
||||||
@@ -80,5 +87,4 @@ func TestWGWatcher_ReEnable(t *testing.T) {
|
|||||||
case <-time.After(10 * time.Second):
|
case <-time.After(10 * time.Second):
|
||||||
t.Errorf("timeout")
|
t.Errorf("timeout")
|
||||||
}
|
}
|
||||||
watcher.DisableWgWatcher()
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -286,8 +287,8 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
|
|||||||
RosenpassAddr: remoteOfferAnswer.RosenpassAddr,
|
RosenpassAddr: remoteOfferAnswer.RosenpassAddr,
|
||||||
LocalIceCandidateType: pair.Local.Type().String(),
|
LocalIceCandidateType: pair.Local.Type().String(),
|
||||||
RemoteIceCandidateType: pair.Remote.Type().String(),
|
RemoteIceCandidateType: pair.Remote.Type().String(),
|
||||||
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
|
LocalIceCandidateEndpoint: net.JoinHostPort(pair.Local.Address(), strconv.Itoa(pair.Local.Port())),
|
||||||
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
|
RemoteIceCandidateEndpoint: net.JoinHostPort(pair.Remote.Address(), strconv.Itoa(pair.Remote.Port())),
|
||||||
Relayed: isRelayed(pair),
|
Relayed: isRelayed(pair),
|
||||||
RelayedOnLocal: isRelayCandidate(pair.Local),
|
RelayedOnLocal: isRelayCandidate(pair.Local),
|
||||||
}
|
}
|
||||||
@@ -328,13 +329,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
|
|||||||
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
||||||
// wait local endpoint configuration
|
// wait local endpoint configuration
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
addrString := pair.Remote.Address()
|
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(pair.Remote.Address(), strconv.Itoa(remoteWgPort)))
|
||||||
parsed, err := netip.ParseAddr(addrString)
|
|
||||||
if (err == nil) && (parsed.Is6()) {
|
|
||||||
addrString = fmt.Sprintf("[%s]", addrString)
|
|
||||||
//IPv6 Literals need to be wrapped in brackets for Resolve*Addr()
|
|
||||||
}
|
|
||||||
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addrString, remoteWgPort))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.log.Warnf("got an error while resolving the udp address, err: %s", err)
|
w.log.Warnf("got an error while resolving the udp address, err: %s", err)
|
||||||
return
|
return
|
||||||
@@ -386,12 +381,44 @@ func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) {
|
||||||
|
sessionID := w.SessionID()
|
||||||
|
stats := agent.GetCandidatePairsStats()
|
||||||
|
localCandidates, _ := agent.GetLocalCandidates()
|
||||||
|
remoteCandidates, _ := agent.GetRemoteCandidates()
|
||||||
|
|
||||||
|
localMap := make(map[string]ice.Candidate)
|
||||||
|
for _, c := range localCandidates {
|
||||||
|
localMap[c.ID()] = c
|
||||||
|
}
|
||||||
|
remoteMap := make(map[string]ice.Candidate)
|
||||||
|
for _, c := range remoteCandidates {
|
||||||
|
remoteMap[c.ID()] = c
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, stat := range stats {
|
||||||
|
if stat.State == ice.CandidatePairStateSucceeded {
|
||||||
|
local, lok := localMap[stat.LocalCandidateID]
|
||||||
|
remote, rok := remoteMap[stat.RemoteCandidateID]
|
||||||
|
if !lok || !rok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms",
|
||||||
|
sessionID,
|
||||||
|
local.NetworkType(), local.Type(), local.Address(),
|
||||||
|
remote.NetworkType(), remote.Type(), remote.Address(),
|
||||||
|
stat.CurrentRoundTripTime*1000)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dialerCancel context.CancelFunc) func(ice.ConnectionState) {
|
func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dialerCancel context.CancelFunc) func(ice.ConnectionState) {
|
||||||
return func(state ice.ConnectionState) {
|
return func(state ice.ConnectionState) {
|
||||||
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
|
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
|
||||||
switch state {
|
switch state {
|
||||||
case ice.ConnectionStateConnected:
|
case ice.ConnectionStateConnected:
|
||||||
w.lastKnownState = ice.ConnectionStateConnected
|
w.lastKnownState = ice.ConnectionStateConnected
|
||||||
|
w.logSuccessfulPaths(agent)
|
||||||
return
|
return
|
||||||
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected, ice.ConnectionStateClosed:
|
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected, ice.ConnectionStateClosed:
|
||||||
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
||||||
|
|||||||
@@ -30,11 +30,9 @@ type WorkerRelay struct {
|
|||||||
relayLock sync.Mutex
|
relayLock sync.Mutex
|
||||||
|
|
||||||
relaySupportedOnRemotePeer atomic.Bool
|
relaySupportedOnRemotePeer atomic.Bool
|
||||||
|
|
||||||
wgWatcher *WGWatcher
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager, stateDump *stateDump) *WorkerRelay {
|
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager) *WorkerRelay {
|
||||||
r := &WorkerRelay{
|
r := &WorkerRelay{
|
||||||
peerCtx: ctx,
|
peerCtx: ctx,
|
||||||
log: log,
|
log: log,
|
||||||
@@ -42,7 +40,6 @@ func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnC
|
|||||||
config: config,
|
config: config,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
relayManager: relayManager,
|
relayManager: relayManager,
|
||||||
wgWatcher: NewWGWatcher(log, config.WgConfig.WgInterface, config.Key, stateDump),
|
|
||||||
}
|
}
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@@ -93,14 +90,6 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
|
|
||||||
w.wgWatcher.EnableWgWatcher(ctx, w.onWGDisconnected)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *WorkerRelay) DisableWgWatcher() {
|
|
||||||
w.wgWatcher.DisableWgWatcher()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
|
func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
|
||||||
return w.relayManager.RelayInstanceAddress()
|
return w.relayManager.RelayInstanceAddress()
|
||||||
}
|
}
|
||||||
@@ -125,14 +114,6 @@ func (w *WorkerRelay) CloseConn() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerRelay) onWGDisconnected() {
|
|
||||||
w.relayLock.Lock()
|
|
||||||
_ = w.relayedConn.Close()
|
|
||||||
w.relayLock.Unlock()
|
|
||||||
|
|
||||||
w.conn.onRelayDisconnected()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
|
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
|
||||||
if !w.relayManager.HasRelayAddress() {
|
if !w.relayManager.HasRelayAddress() {
|
||||||
return false
|
return false
|
||||||
@@ -148,6 +129,5 @@ func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress st
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerRelay) onRelayClientDisconnected() {
|
func (w *WorkerRelay) onRelayClientDisconnected() {
|
||||||
w.wgWatcher.DisableWgWatcher()
|
|
||||||
go w.conn.onRelayDisconnected()
|
go w.conn.onRelayDisconnected()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,138 +0,0 @@
|
|||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
|
|
||||||
type PKCEAuthorizationFlow struct {
|
|
||||||
ProviderConfig PKCEAuthProviderConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow
|
|
||||||
type PKCEAuthProviderConfig struct {
|
|
||||||
// ClientID An IDP application client id
|
|
||||||
ClientID string
|
|
||||||
// ClientSecret An IDP application client secret
|
|
||||||
ClientSecret string
|
|
||||||
// Audience An Audience for to authorization validation
|
|
||||||
Audience string
|
|
||||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
|
||||||
TokenEndpoint string
|
|
||||||
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
|
||||||
AuthorizationEndpoint string
|
|
||||||
// Scopes provides the scopes to be included in the token request
|
|
||||||
Scope string
|
|
||||||
// RedirectURL handles authorization code from IDP manager
|
|
||||||
RedirectURLs []string
|
|
||||||
// UseIDToken indicates if the id token should be used for authentication
|
|
||||||
UseIDToken bool
|
|
||||||
// ClientCertPair is used for mTLS authentication to the IDP
|
|
||||||
ClientCertPair *tls.Certificate
|
|
||||||
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
|
||||||
DisablePromptLogin bool
|
|
||||||
// LoginFlag is used to configure the PKCE flow login behavior
|
|
||||||
LoginFlag common.LoginFlag
|
|
||||||
// LoginHint is used to pre-fill the email/username field during authentication
|
|
||||||
LoginHint string
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
|
||||||
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) {
|
|
||||||
// validate our peer's Wireguard PRIVATE key
|
|
||||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var mgmTLSEnabled bool
|
|
||||||
if mgmURL.Scheme == "https" {
|
|
||||||
mgmTLSEnabled = true
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
|
||||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err = mgmClient.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to close the Management service client %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
serverKey, err := mgmClient.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey)
|
|
||||||
if err != nil {
|
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
|
||||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
log.Errorf("failed to retrieve pkce flow: %v", err)
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
authFlow := PKCEAuthorizationFlow{
|
|
||||||
ProviderConfig: PKCEAuthProviderConfig{
|
|
||||||
Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(),
|
|
||||||
ClientID: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(),
|
|
||||||
ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
|
||||||
TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
|
||||||
AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(),
|
|
||||||
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
|
|
||||||
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
|
|
||||||
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
|
||||||
ClientCertPair: clientCert,
|
|
||||||
DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
|
|
||||||
LoginFlag: common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
err = isPKCEProviderConfigValid(authFlow.ProviderConfig)
|
|
||||||
if err != nil {
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return authFlow, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error {
|
|
||||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
|
||||||
if config.ClientID == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
|
||||||
}
|
|
||||||
if config.TokenEndpoint == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
|
||||||
}
|
|
||||||
if config.AuthorizationEndpoint == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint")
|
|
||||||
}
|
|
||||||
if config.Scope == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes")
|
|
||||||
}
|
|
||||||
if config.RedirectURLs == nil {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
262
client/internal/proxy/manager_darwin.go
Normal file
262
client/internal/proxy/manager_darwin.go
Normal file
@@ -0,0 +1,262 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
const networksetupPath = "/usr/sbin/networksetup"
|
||||||
|
|
||||||
|
// Manager handles system-wide proxy configuration on macOS.
|
||||||
|
type Manager struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
stateManager *statemanager.Manager
|
||||||
|
modifiedServices []string
|
||||||
|
enabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new proxy manager.
|
||||||
|
func NewManager(stateManager *statemanager.Manager) *Manager {
|
||||||
|
return &Manager{
|
||||||
|
stateManager: stateManager,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetActiveNetworkServices returns the list of active network services.
|
||||||
|
func GetActiveNetworkServices() ([]string, error) {
|
||||||
|
cmd := exec.Command(networksetupPath, "-listallnetworkservices")
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("list network services: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := strings.Split(string(out), "\n")
|
||||||
|
var services []string
|
||||||
|
for _, line := range lines {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if line == "" || strings.HasPrefix(line, "*") || strings.Contains(line, "asterisk") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
services = append(services, line)
|
||||||
|
}
|
||||||
|
return services, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnableWebProxy enables web proxy for all active network services.
|
||||||
|
func (m *Manager) EnableWebProxy(host string, port int) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.enabled {
|
||||||
|
log.Debug("web proxy already enabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
services, err := GetActiveNetworkServices()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var modifiedServices []string
|
||||||
|
for _, service := range services {
|
||||||
|
if err := m.enableProxyForService(service, host, port); err != nil {
|
||||||
|
log.Warnf("enable proxy for %s: %v", service, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
modifiedServices = append(modifiedServices, service)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.modifiedServices = modifiedServices
|
||||||
|
m.enabled = true
|
||||||
|
m.updateState()
|
||||||
|
|
||||||
|
log.Infof("enabled web proxy on %d services -> %s:%d", len(modifiedServices), host, port)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) enableProxyForService(service, host string, port int) error {
|
||||||
|
portStr := fmt.Sprintf("%d", port)
|
||||||
|
|
||||||
|
// Set web proxy (HTTP)
|
||||||
|
cmd := exec.Command(networksetupPath, "-setwebproxy", service, host, portStr)
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("set web proxy: %w, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable web proxy
|
||||||
|
cmd = exec.Command(networksetupPath, "-setwebproxystate", service, "on")
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("enable web proxy state: %w, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set secure web proxy (HTTPS)
|
||||||
|
cmd = exec.Command(networksetupPath, "-setsecurewebproxy", service, host, portStr)
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("set secure web proxy: %w, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable secure web proxy
|
||||||
|
cmd = exec.Command(networksetupPath, "-setsecurewebproxystate", service, "on")
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("enable secure web proxy state: %w, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("enabled proxy for service %s", service)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableWebProxy disables web proxy for all modified network services.
|
||||||
|
func (m *Manager) DisableWebProxy() error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if !m.enabled {
|
||||||
|
log.Debug("web proxy already disabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
services := m.modifiedServices
|
||||||
|
if len(services) == 0 {
|
||||||
|
services, _ = GetActiveNetworkServices()
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, service := range services {
|
||||||
|
if err := m.disableProxyForService(service); err != nil {
|
||||||
|
log.Warnf("disable proxy for %s: %v", service, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.modifiedServices = nil
|
||||||
|
m.enabled = false
|
||||||
|
m.updateState()
|
||||||
|
|
||||||
|
log.Info("disabled web proxy")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) disableProxyForService(service string) error {
|
||||||
|
// Disable web proxy (HTTP)
|
||||||
|
cmd := exec.Command(networksetupPath, "-setwebproxystate", service, "off")
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("disable web proxy: %w, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable secure web proxy (HTTPS)
|
||||||
|
cmd = exec.Command(networksetupPath, "-setsecurewebproxystate", service, "off")
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("disable secure web proxy: %w, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("disabled proxy for service %s", service)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAutoproxyURL sets the automatic proxy configuration URL (PAC file).
|
||||||
|
func (m *Manager) SetAutoproxyURL(pacURL string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
services, err := GetActiveNetworkServices()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var modifiedServices []string
|
||||||
|
for _, service := range services {
|
||||||
|
cmd := exec.Command(networksetupPath, "-setautoproxyurl", service, pacURL)
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
log.Warnf("set autoproxy for %s: %v, output: %s", service, err, out)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd = exec.Command(networksetupPath, "-setautoproxystate", service, "on")
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
log.Warnf("enable autoproxy for %s: %v, output: %s", service, err, out)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
modifiedServices = append(modifiedServices, service)
|
||||||
|
log.Debugf("set autoproxy URL for %s -> %s", service, pacURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.modifiedServices = modifiedServices
|
||||||
|
m.enabled = true
|
||||||
|
m.updateState()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableAutoproxy disables automatic proxy configuration.
|
||||||
|
func (m *Manager) DisableAutoproxy() error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
services := m.modifiedServices
|
||||||
|
if len(services) == 0 {
|
||||||
|
services, _ = GetActiveNetworkServices()
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, service := range services {
|
||||||
|
cmd := exec.Command(networksetupPath, "-setautoproxystate", service, "off")
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
log.Warnf("disable autoproxy for %s: %v, output: %s", service, err, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.modifiedServices = nil
|
||||||
|
m.enabled = false
|
||||||
|
m.updateState()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsEnabled returns whether the proxy is currently enabled.
|
||||||
|
func (m *Manager) IsEnabled() bool {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return m.enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore restores proxy settings from a previous state.
|
||||||
|
func (m *Manager) Restore(services []string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
for _, service := range services {
|
||||||
|
if err := m.disableProxyForService(service); err != nil {
|
||||||
|
log.Warnf("restore proxy for %s: %v", service, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.modifiedServices = nil
|
||||||
|
m.enabled = false
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) updateState() {
|
||||||
|
if m.stateManager == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.enabled && len(m.modifiedServices) > 0 {
|
||||||
|
state := &ShutdownState{
|
||||||
|
ModifiedServices: m.modifiedServices,
|
||||||
|
}
|
||||||
|
if err := m.stateManager.UpdateState(state); err != nil {
|
||||||
|
log.Errorf("update proxy state: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := m.stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||||
|
log.Debugf("delete proxy state: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
45
client/internal/proxy/manager_other.go
Normal file
45
client/internal/proxy/manager_other.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
//go:build !darwin || ios
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager is a no-op proxy manager for non-macOS platforms.
|
||||||
|
type Manager struct{}
|
||||||
|
|
||||||
|
// NewManager creates a new proxy manager (no-op on non-macOS).
|
||||||
|
func NewManager(_ *statemanager.Manager) *Manager {
|
||||||
|
return &Manager{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnableWebProxy is a no-op on non-macOS platforms.
|
||||||
|
func (m *Manager) EnableWebProxy(host string, port int) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableWebProxy is a no-op on non-macOS platforms.
|
||||||
|
func (m *Manager) DisableWebProxy() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAutoproxyURL is a no-op on non-macOS platforms.
|
||||||
|
func (m *Manager) SetAutoproxyURL(pacURL string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableAutoproxy is a no-op on non-macOS platforms.
|
||||||
|
func (m *Manager) DisableAutoproxy() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsEnabled always returns false on non-macOS platforms.
|
||||||
|
func (m *Manager) IsEnabled() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore is a no-op on non-macOS platforms.
|
||||||
|
func (m *Manager) Restore(services []string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
88
client/internal/proxy/manager_test.go
Normal file
88
client/internal/proxy/manager_test.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetActiveNetworkServices(t *testing.T) {
|
||||||
|
services, err := GetActiveNetworkServices()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, services, "should have at least one network service")
|
||||||
|
|
||||||
|
// Check that services don't contain invalid entries
|
||||||
|
for _, service := range services {
|
||||||
|
assert.NotEmpty(t, service)
|
||||||
|
assert.NotContains(t, service, "*")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_EnableDisableWebProxy(t *testing.T) {
|
||||||
|
// Skip this test in CI as it requires admin privileges
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping proxy test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
m := NewManager(nil)
|
||||||
|
assert.NotNil(t, m)
|
||||||
|
assert.False(t, m.IsEnabled())
|
||||||
|
|
||||||
|
// This test would require admin privileges to actually enable the proxy
|
||||||
|
// So we just test the basic state management
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShutdownState_Name(t *testing.T) {
|
||||||
|
state := &ShutdownState{}
|
||||||
|
assert.Equal(t, "proxy_state", state.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShutdownState_Cleanup_EmptyServices(t *testing.T) {
|
||||||
|
state := &ShutdownState{
|
||||||
|
ModifiedServices: []string{},
|
||||||
|
}
|
||||||
|
err := state.Cleanup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContains(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
s string
|
||||||
|
substr string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"Enabled: Yes", "Enabled: Yes", true},
|
||||||
|
{"Enabled: No", "Enabled: Yes", false},
|
||||||
|
{"Server: 127.0.0.1\nEnabled: Yes\nPort: 8080", "Enabled: Yes", true},
|
||||||
|
{"", "Enabled: Yes", false},
|
||||||
|
{"Enabled: Yes", "", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.s+"_"+tt.substr, func(t *testing.T) {
|
||||||
|
got := contains(tt.s, tt.substr)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsProxyEnabled(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
output string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"Enabled: Yes\nServer: 127.0.0.1\nPort: 8080", true},
|
||||||
|
{"Enabled: No\nServer: \nPort: 0", false},
|
||||||
|
{"Server: 127.0.0.1\nEnabled: Yes\nPort: 8080", true},
|
||||||
|
{"", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.output, func(t *testing.T) {
|
||||||
|
got := isProxyEnabled(tt.output)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
105
client/internal/proxy/state_darwin.go
Normal file
105
client/internal/proxy/state_darwin.go
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ShutdownState stores proxy state for cleanup on unclean shutdown.
|
||||||
|
type ShutdownState struct {
|
||||||
|
ModifiedServices []string `json:"modified_services"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the state name for persistence.
|
||||||
|
func (s *ShutdownState) Name() string {
|
||||||
|
return "proxy_state"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup restores proxy settings after an unclean shutdown.
|
||||||
|
func (s *ShutdownState) Cleanup() error {
|
||||||
|
if len(s.ModifiedServices) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("cleaning up proxy state for %d services", len(s.ModifiedServices))
|
||||||
|
|
||||||
|
for _, service := range s.ModifiedServices {
|
||||||
|
// Disable web proxy (HTTP)
|
||||||
|
cmd := exec.Command(networksetupPath, "-setwebproxystate", service, "off")
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
log.Warnf("cleanup web proxy for %s: %v, output: %s", service, err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable secure web proxy (HTTPS)
|
||||||
|
cmd = exec.Command(networksetupPath, "-setsecurewebproxystate", service, "off")
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
log.Warnf("cleanup secure web proxy for %s: %v, output: %s", service, err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable autoproxy
|
||||||
|
cmd = exec.Command(networksetupPath, "-setautoproxystate", service, "off")
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
log.Warnf("cleanup autoproxy for %s: %v, output: %s", service, err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("cleaned up proxy for service %s", service)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterState registers the proxy state with the state manager.
|
||||||
|
func RegisterState(stateManager *statemanager.Manager) {
|
||||||
|
if stateManager == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stateManager.RegisterState(&ShutdownState{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProxyState returns the current proxy state from the command line.
|
||||||
|
func GetProxyState(service string) (webProxy, secureProxy, autoProxy bool, err error) {
|
||||||
|
// Check web proxy state
|
||||||
|
cmd := exec.Command(networksetupPath, "-getwebproxy", service)
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return false, false, false, fmt.Errorf("get web proxy: %w", err)
|
||||||
|
}
|
||||||
|
webProxy = isProxyEnabled(string(out))
|
||||||
|
|
||||||
|
// Check secure web proxy state
|
||||||
|
cmd = exec.Command(networksetupPath, "-getsecurewebproxy", service)
|
||||||
|
out, err = cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return false, false, false, fmt.Errorf("get secure web proxy: %w", err)
|
||||||
|
}
|
||||||
|
secureProxy = isProxyEnabled(string(out))
|
||||||
|
|
||||||
|
// Check autoproxy state
|
||||||
|
cmd = exec.Command(networksetupPath, "-getautoproxyurl", service)
|
||||||
|
out, err = cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return false, false, false, fmt.Errorf("get autoproxy: %w", err)
|
||||||
|
}
|
||||||
|
autoProxy = isProxyEnabled(string(out))
|
||||||
|
|
||||||
|
return webProxy, secureProxy, autoProxy, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isProxyEnabled(output string) bool {
|
||||||
|
return !contains(output, "Enabled: No") && contains(output, "Enabled: Yes")
|
||||||
|
}
|
||||||
|
|
||||||
|
func contains(s, substr string) bool {
|
||||||
|
for i := 0; i <= len(s)-len(substr); i++ {
|
||||||
|
if s[i:i+len(substr)] == substr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
24
client/internal/proxy/state_other.go
Normal file
24
client/internal/proxy/state_other.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
//go:build !darwin || ios
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ShutdownState is a no-op state for non-macOS platforms.
|
||||||
|
type ShutdownState struct{}
|
||||||
|
|
||||||
|
// Name returns the state name.
|
||||||
|
func (s *ShutdownState) Name() string {
|
||||||
|
return "proxy_state"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup is a no-op on non-macOS platforms.
|
||||||
|
func (s *ShutdownState) Cleanup() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterState is a no-op on non-macOS platforms.
|
||||||
|
func RegisterState(stateManager *statemanager.Manager) {
|
||||||
|
}
|
||||||
@@ -17,6 +17,11 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultLog = slog.LevelInfo
|
||||||
|
defaultLogLevelVar = "NB_ROSENPASS_LOG_LEVEL"
|
||||||
|
)
|
||||||
|
|
||||||
func hashRosenpassKey(key []byte) string {
|
func hashRosenpassKey(key []byte) string {
|
||||||
hasher := sha256.New()
|
hasher := sha256.New()
|
||||||
hasher.Write(key)
|
hasher.Write(key)
|
||||||
@@ -34,6 +39,7 @@ type Manager struct {
|
|||||||
server *rp.Server
|
server *rp.Server
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
port int
|
port int
|
||||||
|
wgIface PresharedKeySetter
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManager creates a new Rosenpass manager
|
// NewManager creates a new Rosenpass manager
|
||||||
@@ -44,7 +50,7 @@ func NewManager(preSharedKey *wgtypes.Key, wgIfaceName string) (*Manager, error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
rpKeyHash := hashRosenpassKey(public)
|
rpKeyHash := hashRosenpassKey(public)
|
||||||
log.Debugf("generated new rosenpass key pair with public key %s", rpKeyHash)
|
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
|
||||||
return &Manager{ifaceName: wgIfaceName, rpKeyHash: rpKeyHash, spk: public, ssk: secret, preSharedKey: (*[32]byte)(preSharedKey), rpPeerIDs: make(map[string]*rp.PeerID), lock: sync.Mutex{}}, nil
|
return &Manager{ifaceName: wgIfaceName, rpKeyHash: rpKeyHash, spk: public, ssk: secret, preSharedKey: (*[32]byte)(preSharedKey), rpPeerIDs: make(map[string]*rp.PeerID), lock: sync.Mutex{}}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,7 +106,7 @@ func (m *Manager) removePeer(wireGuardPubKey string) error {
|
|||||||
|
|
||||||
func (m *Manager) generateConfig() (rp.Config, error) {
|
func (m *Manager) generateConfig() (rp.Config, error) {
|
||||||
opts := &slog.HandlerOptions{
|
opts := &slog.HandlerOptions{
|
||||||
Level: slog.LevelDebug,
|
Level: getLogLevel(),
|
||||||
}
|
}
|
||||||
logger := slog.New(slog.NewTextHandler(os.Stdout, opts))
|
logger := slog.New(slog.NewTextHandler(os.Stdout, opts))
|
||||||
cfg := rp.Config{Logger: logger}
|
cfg := rp.Config{Logger: logger}
|
||||||
@@ -109,7 +115,13 @@ func (m *Manager) generateConfig() (rp.Config, error) {
|
|||||||
cfg.SecretKey = m.ssk
|
cfg.SecretKey = m.ssk
|
||||||
|
|
||||||
cfg.Peers = []rp.PeerConfig{}
|
cfg.Peers = []rp.PeerConfig{}
|
||||||
m.rpWgHandler, _ = NewNetbirdHandler(m.preSharedKey, m.ifaceName)
|
|
||||||
|
m.lock.Lock()
|
||||||
|
m.rpWgHandler = NewNetbirdHandler()
|
||||||
|
if m.wgIface != nil {
|
||||||
|
m.rpWgHandler.SetInterface(m.wgIface)
|
||||||
|
}
|
||||||
|
m.lock.Unlock()
|
||||||
|
|
||||||
cfg.Handlers = []rp.Handler{m.rpWgHandler}
|
cfg.Handlers = []rp.Handler{m.rpWgHandler}
|
||||||
|
|
||||||
@@ -126,6 +138,26 @@ func (m *Manager) generateConfig() (rp.Config, error) {
|
|||||||
return cfg, nil
|
return cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getLogLevel() slog.Level {
|
||||||
|
level, ok := os.LookupEnv(defaultLogLevelVar)
|
||||||
|
if !ok {
|
||||||
|
return defaultLog
|
||||||
|
}
|
||||||
|
switch strings.ToLower(level) {
|
||||||
|
case "debug":
|
||||||
|
return slog.LevelDebug
|
||||||
|
case "info":
|
||||||
|
return slog.LevelInfo
|
||||||
|
case "warn":
|
||||||
|
return slog.LevelWarn
|
||||||
|
case "error":
|
||||||
|
return slog.LevelError
|
||||||
|
default:
|
||||||
|
log.Warnf("unknown log level: %s. Using default %s", level, defaultLog.String())
|
||||||
|
return defaultLog
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) OnDisconnected(peerKey string) {
|
func (m *Manager) OnDisconnected(peerKey string) {
|
||||||
m.lock.Lock()
|
m.lock.Lock()
|
||||||
defer m.lock.Unlock()
|
defer m.lock.Unlock()
|
||||||
@@ -172,6 +204,20 @@ func (m *Manager) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetInterface sets the WireGuard interface for the rosenpass handler.
|
||||||
|
// This can be called before or after Run() - the interface will be stored
|
||||||
|
// and passed to the handler when it's created or updated immediately if
|
||||||
|
// already running.
|
||||||
|
func (m *Manager) SetInterface(iface PresharedKeySetter) {
|
||||||
|
m.lock.Lock()
|
||||||
|
defer m.lock.Unlock()
|
||||||
|
|
||||||
|
m.wgIface = iface
|
||||||
|
if m.rpWgHandler != nil {
|
||||||
|
m.rpWgHandler.SetInterface(iface)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// OnConnected is a handler function that is triggered when a connection to a remote peer establishes
|
// OnConnected is a handler function that is triggered when a connection to a remote peer establishes
|
||||||
func (m *Manager) OnConnected(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) {
|
func (m *Manager) OnConnected(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) {
|
||||||
m.lock.Lock()
|
m.lock.Lock()
|
||||||
@@ -192,6 +238,20 @@ func (m *Manager) OnConnected(remoteWireGuardKey string, remoteRosenpassPubKey [
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsPresharedKeyInitialized returns true if Rosenpass has completed a handshake
|
||||||
|
// and set a PSK for the given WireGuard peer.
|
||||||
|
func (m *Manager) IsPresharedKeyInitialized(wireGuardPubKey string) bool {
|
||||||
|
m.lock.Lock()
|
||||||
|
defer m.lock.Unlock()
|
||||||
|
|
||||||
|
peerID, ok := m.rpPeerIDs[wireGuardPubKey]
|
||||||
|
if !ok || peerID == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.rpWgHandler.IsPeerInitialized(*peerID)
|
||||||
|
}
|
||||||
|
|
||||||
func findRandomAvailableUDPPort() (int, error) {
|
func findRandomAvailableUDPPort() (int, error) {
|
||||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,46 +1,50 @@
|
|||||||
package rosenpass
|
package rosenpass
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"sync"
|
||||||
"log/slog"
|
|
||||||
|
|
||||||
rp "cunicu.li/go-rosenpass"
|
rp "cunicu.li/go-rosenpass"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PresharedKeySetter is the interface for setting preshared keys on WireGuard peers.
|
||||||
|
// This minimal interface allows rosenpass to update PSKs without depending on the full WGIface.
|
||||||
|
type PresharedKeySetter interface {
|
||||||
|
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
|
||||||
|
}
|
||||||
|
|
||||||
type wireGuardPeer struct {
|
type wireGuardPeer struct {
|
||||||
Interface string
|
Interface string
|
||||||
PublicKey rp.Key
|
PublicKey rp.Key
|
||||||
}
|
}
|
||||||
|
|
||||||
type NetbirdHandler struct {
|
type NetbirdHandler struct {
|
||||||
ifaceName string
|
mu sync.Mutex
|
||||||
client *wgctrl.Client
|
iface PresharedKeySetter
|
||||||
peers map[rp.PeerID]wireGuardPeer
|
peers map[rp.PeerID]wireGuardPeer
|
||||||
presharedKey [32]byte
|
initializedPeers map[rp.PeerID]bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNetbirdHandler(preSharedKey *[32]byte, wgIfaceName string) (hdlr *NetbirdHandler, err error) {
|
func NewNetbirdHandler() *NetbirdHandler {
|
||||||
hdlr = &NetbirdHandler{
|
return &NetbirdHandler{
|
||||||
ifaceName: wgIfaceName,
|
peers: map[rp.PeerID]wireGuardPeer{},
|
||||||
peers: map[rp.PeerID]wireGuardPeer{},
|
initializedPeers: map[rp.PeerID]bool{},
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if preSharedKey != nil {
|
// SetInterface sets the WireGuard interface for the handler.
|
||||||
hdlr.presharedKey = *preSharedKey
|
// This must be called after the WireGuard interface is created.
|
||||||
}
|
func (h *NetbirdHandler) SetInterface(iface PresharedKeySetter) {
|
||||||
|
h.mu.Lock()
|
||||||
if hdlr.client, err = wgctrl.New(); err != nil {
|
defer h.mu.Unlock()
|
||||||
return nil, fmt.Errorf("failed to creat WireGuard client: %w", err)
|
h.iface = iface
|
||||||
}
|
|
||||||
|
|
||||||
return hdlr, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *NetbirdHandler) AddPeer(pid rp.PeerID, intf string, pk rp.Key) {
|
func (h *NetbirdHandler) AddPeer(pid rp.PeerID, intf string, pk rp.Key) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
h.peers[pid] = wireGuardPeer{
|
h.peers[pid] = wireGuardPeer{
|
||||||
Interface: intf,
|
Interface: intf,
|
||||||
PublicKey: pk,
|
PublicKey: pk,
|
||||||
@@ -48,79 +52,61 @@ func (h *NetbirdHandler) AddPeer(pid rp.PeerID, intf string, pk rp.Key) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *NetbirdHandler) RemovePeer(pid rp.PeerID) {
|
func (h *NetbirdHandler) RemovePeer(pid rp.PeerID) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
delete(h.peers, pid)
|
delete(h.peers, pid)
|
||||||
|
delete(h.initializedPeers, pid)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsPeerInitialized returns true if Rosenpass has completed a handshake
|
||||||
|
// and set a PSK for this peer.
|
||||||
|
func (h *NetbirdHandler) IsPeerInitialized(pid rp.PeerID) bool {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
return h.initializedPeers[pid]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *NetbirdHandler) HandshakeCompleted(pid rp.PeerID, key rp.Key) {
|
func (h *NetbirdHandler) HandshakeCompleted(pid rp.PeerID, key rp.Key) {
|
||||||
log.Debug("Handshake complete")
|
|
||||||
h.outputKey(rp.KeyOutputReasonStale, pid, key)
|
h.outputKey(rp.KeyOutputReasonStale, pid, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *NetbirdHandler) HandshakeExpired(pid rp.PeerID) {
|
func (h *NetbirdHandler) HandshakeExpired(pid rp.PeerID) {
|
||||||
key, _ := rp.GeneratePresharedKey()
|
key, _ := rp.GeneratePresharedKey()
|
||||||
log.Debug("Handshake expired")
|
|
||||||
h.outputKey(rp.KeyOutputReasonStale, pid, key)
|
h.outputKey(rp.KeyOutputReasonStale, pid, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *NetbirdHandler) outputKey(_ rp.KeyOutputReason, pid rp.PeerID, psk rp.Key) {
|
func (h *NetbirdHandler) outputKey(_ rp.KeyOutputReason, pid rp.PeerID, psk rp.Key) {
|
||||||
|
h.mu.Lock()
|
||||||
|
iface := h.iface
|
||||||
wg, ok := h.peers[pid]
|
wg, ok := h.peers[pid]
|
||||||
|
isInitialized := h.initializedPeers[pid]
|
||||||
|
h.mu.Unlock()
|
||||||
|
|
||||||
|
if iface == nil {
|
||||||
|
log.Warn("rosenpass: interface not set, cannot update preshared key")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
device, err := h.client.Device(h.ifaceName)
|
peerKey := wgtypes.Key(wg.PublicKey).String()
|
||||||
if err != nil {
|
pskKey := wgtypes.Key(psk)
|
||||||
log.Errorf("Failed to get WireGuard device: %v", err)
|
|
||||||
|
// Use updateOnly=true for later rotations (peer already has Rosenpass PSK)
|
||||||
|
// Use updateOnly=false for first rotation (peer has original/empty PSK)
|
||||||
|
if err := iface.SetPresharedKey(peerKey, pskKey, isInitialized); err != nil {
|
||||||
|
log.Errorf("Failed to apply rosenpass key: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
config := []wgtypes.PeerConfig{
|
|
||||||
{
|
|
||||||
UpdateOnly: true,
|
|
||||||
PublicKey: wgtypes.Key(wg.PublicKey),
|
|
||||||
PresharedKey: (*wgtypes.Key)(&psk),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, peer := range device.Peers {
|
|
||||||
if peer.PublicKey == wgtypes.Key(wg.PublicKey) {
|
|
||||||
if publicKeyEmpty(peer.PresharedKey) || peer.PresharedKey == h.presharedKey {
|
|
||||||
log.Debugf("Restart wireguard connection to peer %s", peer.PublicKey)
|
|
||||||
config = []wgtypes.PeerConfig{
|
|
||||||
{
|
|
||||||
PublicKey: wgtypes.Key(wg.PublicKey),
|
|
||||||
PresharedKey: (*wgtypes.Key)(&psk),
|
|
||||||
Endpoint: peer.Endpoint,
|
|
||||||
AllowedIPs: peer.AllowedIPs,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
err = h.client.ConfigureDevice(wg.Interface, wgtypes.Config{
|
|
||||||
Peers: []wgtypes.PeerConfig{
|
|
||||||
{
|
|
||||||
Remove: true,
|
|
||||||
PublicKey: wgtypes.Key(wg.PublicKey),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
slog.Debug("Failed to remove peer")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// Mark peer as isInitialized after the successful first rotation
|
||||||
|
if !isInitialized {
|
||||||
|
h.mu.Lock()
|
||||||
|
if _, exists := h.peers[pid]; exists {
|
||||||
|
h.initializedPeers[pid] = true
|
||||||
}
|
}
|
||||||
}
|
h.mu.Unlock()
|
||||||
|
|
||||||
if err = h.client.ConfigureDevice(wg.Interface, wgtypes.Config{
|
|
||||||
Peers: config,
|
|
||||||
}); err != nil {
|
|
||||||
log.Errorf("Failed to apply rosenpass key: %v", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func publicKeyEmpty(key wgtypes.Key) bool {
|
|
||||||
for _, b := range key {
|
|
||||||
if b != 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
|
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
|
||||||
@@ -35,6 +37,11 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
|
|||||||
return false, errors.New("not supported on mobile platforms")
|
return false, errors.New("not supported on mobile platforms")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
log.Debugf("Interface monitor: skipped in netstack mode")
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
if ifaceName == "" {
|
if ifaceName == "" {
|
||||||
log.Debugf("Interface monitor: empty interface name, skipping monitor")
|
log.Debugf("Interface monitor: empty interface name, skipping monitor")
|
||||||
return false, errors.New("empty interface name")
|
return false, errors.New("empty interface name")
|
||||||
|
|||||||
@@ -263,7 +263,14 @@ func (c *Client) IsLoginRequired() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
needsLogin, err := internal.IsLoginRequired(ctx, cfg)
|
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("IsLoginRequired: failed to create auth client: %v", err)
|
||||||
|
return true // Assume login is required if we can't create auth client
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
needsLogin, err := authClient.IsLoginRequired(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("IsLoginRequired: check failed: %v", err)
|
log.Errorf("IsLoginRequired: check failed: %v", err)
|
||||||
// If the check fails, assume login is required to be safe
|
// If the check fails, assume login is required to be safe
|
||||||
@@ -314,16 +321,19 @@ func (c *Client) LoginForMobile() string {
|
|||||||
|
|
||||||
// This could cause a potential race condition with loading the extension which need to be handled on swift side
|
// This could cause a potential race condition with loading the extension which need to be handled on swift side
|
||||||
go func() {
|
go func() {
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
tokenInfo, err := oAuthFlow.WaitToken(ctx, flowInfo)
|
||||||
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
|
|
||||||
defer cancel()
|
|
||||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("LoginForMobile: WaitToken failed: %v", err)
|
log.Errorf("LoginForMobile: WaitToken failed: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
jwtToken := tokenInfo.GetTokenToUse()
|
jwtToken := tokenInfo.GetTokenToUse()
|
||||||
if err := internal.Login(ctx, cfg, "", jwtToken); err != nil {
|
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("LoginForMobile: failed to create auth client: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
if err, _ := authClient.Login(ctx, "", jwtToken); err != nil {
|
||||||
log.Errorf("LoginForMobile: Login failed: %v", err)
|
log.Errorf("LoginForMobile: Login failed: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,13 +7,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
gstatus "google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/cmd"
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/auth"
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
@@ -90,34 +85,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||||
supportsSSO := true
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
if err != nil {
|
||||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
return false, fmt.Errorf("failed to create auth client: %v", err)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
}
|
||||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
defer authClient.Close()
|
||||||
s, ok := gstatus.FromError(err)
|
|
||||||
if !ok {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
|
|
||||||
supportsSSO = false
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
|
||||||
}
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to check SSO support: %v", err)
|
||||||
return err
|
}
|
||||||
})
|
|
||||||
|
|
||||||
if !supportsSSO {
|
if !supportsSSO {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
|
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
|
||||||
// which are blocked by the tvOS sandbox in App Group containers
|
// which are blocked by the tvOS sandbox in App Group containers
|
||||||
err = profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
|
err = profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
|
||||||
@@ -141,19 +123,17 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||||
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
//nolint
|
//nolint
|
||||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||||
|
err, _ = authClient.Login(ctxWithValues, setupKey, "")
|
||||||
err := a.withBackOff(a.ctx, func() error {
|
|
||||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
|
||||||
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
|
||||||
// we got an answer from management, exit backoff earlier
|
|
||||||
return backoff.Permanent(backoffErr)
|
|
||||||
}
|
|
||||||
return backoffErr
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
|
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
|
||||||
@@ -164,15 +144,16 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
|
|||||||
// LoginSync performs a synchronous login check without UI interaction
|
// LoginSync performs a synchronous login check without UI interaction
|
||||||
// Used for background VPN connection where user should already be authenticated
|
// Used for background VPN connection where user should already be authenticated
|
||||||
func (a *Auth) LoginSync() error {
|
func (a *Auth) LoginSync() error {
|
||||||
var needsLogin bool
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
// check if we need to generate JWT token
|
// check if we need to generate JWT token
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
needsLogin, err := authClient.IsLoginRequired(a.ctx)
|
||||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
|
||||||
return
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("failed to check login requirement: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
@@ -180,15 +161,12 @@ func (a *Auth) LoginSync() error {
|
|||||||
return fmt.Errorf("not authenticated")
|
return fmt.Errorf("not authenticated")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = a.withBackOff(a.ctx, func() error {
|
err, isAuthError := authClient.Login(a.ctx, "", jwtToken)
|
||||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
|
||||||
// PermissionDenied means registration is required or peer is blocked
|
|
||||||
return backoff.Permanent(err)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if isAuthError {
|
||||||
|
// PermissionDenied means registration is required or peer is blocked
|
||||||
|
return fmt.Errorf("authentication error: %v", err)
|
||||||
|
}
|
||||||
return fmt.Errorf("login failed: %v", err)
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -225,8 +203,6 @@ func (a *Auth) LoginWithDeviceName(resultListener ErrListener, urlOpener URLOpen
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName string) error {
|
func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName string) error {
|
||||||
var needsLogin bool
|
|
||||||
|
|
||||||
// Create context with device name if provided
|
// Create context with device name if provided
|
||||||
ctx := a.ctx
|
ctx := a.ctx
|
||||||
if deviceName != "" {
|
if deviceName != "" {
|
||||||
@@ -234,33 +210,33 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin
|
|||||||
ctx = context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
ctx = context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if we need to generate JWT token
|
authClient, err := auth.NewAuth(ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
err := a.withBackOff(ctx, func() (err error) {
|
|
||||||
needsLogin, err = internal.IsLoginRequired(ctx, a.config)
|
|
||||||
return
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
// check if we need to generate JWT token
|
||||||
|
needsLogin, err := authClient.IsLoginRequired(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to check login requirement: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
if needsLogin {
|
if needsLogin {
|
||||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, forceDeviceAuth)
|
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, forceDeviceAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||||
}
|
}
|
||||||
jwtToken = tokenInfo.GetTokenToUse()
|
jwtToken = tokenInfo.GetTokenToUse()
|
||||||
}
|
}
|
||||||
|
|
||||||
err = a.withBackOff(ctx, func() error {
|
err, isAuthError := authClient.Login(ctx, "", jwtToken)
|
||||||
err := internal.Login(ctx, a.config, "", jwtToken)
|
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
|
||||||
// PermissionDenied means registration is required or peer is blocked
|
|
||||||
return backoff.Permanent(err)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if isAuthError {
|
||||||
|
// PermissionDenied means registration is required or peer is blocked
|
||||||
|
return fmt.Errorf("authentication error: %v", err)
|
||||||
|
}
|
||||||
return fmt.Errorf("login failed: %v", err)
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,10 +261,10 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin
|
|||||||
|
|
||||||
const authInfoRequestTimeout = 30 * time.Second
|
const authInfoRequestTimeout = 30 * time.Second
|
||||||
|
|
||||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
|
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, forceDeviceAuth, "")
|
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, forceDeviceAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use a bounded timeout for the auth info request to prevent indefinite hangs
|
// Use a bounded timeout for the auth info request to prevent indefinite hangs
|
||||||
@@ -313,15 +289,6 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool)
|
|||||||
return &tokenInfo, nil
|
return &tokenInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
|
||||||
return backoff.RetryNotify(
|
|
||||||
bf,
|
|
||||||
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
|
|
||||||
func(err error, duration time.Duration) {
|
|
||||||
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetConfigJSON returns the current config as a JSON string.
|
// GetConfigJSON returns the current config as a JSON string.
|
||||||
// This can be used by the caller to persist the config via alternative storage
|
// This can be used by the caller to persist the config via alternative storage
|
||||||
// mechanisms (e.g., UserDefaults on tvOS where file writes are blocked).
|
// mechanisms (e.g., UserDefaults on tvOS where file writes are blocked).
|
||||||
|
|||||||
76
client/jobexec/executor.go
Normal file
76
client/jobexec/executor.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package jobexec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/debug"
|
||||||
|
"github.com/netbirdio/netbird/upload-server/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
MaxBundleWaitTime = 60 * time.Minute // maximum wait time for bundle generation (1 hour)
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrJobNotImplemented = errors.New("job not implemented")
|
||||||
|
)
|
||||||
|
|
||||||
|
type Executor struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewExecutor() *Executor {
|
||||||
|
return &Executor{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Executor) BundleJob(ctx context.Context, debugBundleDependencies debug.GeneratorDependencies, params debug.BundleConfig, waitForDuration time.Duration, mgmURL string) (string, error) {
|
||||||
|
if waitForDuration > MaxBundleWaitTime {
|
||||||
|
log.Warnf("bundle wait time %v exceeds maximum %v, capping to maximum", waitForDuration, MaxBundleWaitTime)
|
||||||
|
waitForDuration = MaxBundleWaitTime
|
||||||
|
}
|
||||||
|
|
||||||
|
if waitForDuration > 0 {
|
||||||
|
if err := waitFor(ctx, waitForDuration); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("execute debug bundle generation")
|
||||||
|
|
||||||
|
bundleGenerator := debug.NewBundleGenerator(debugBundleDependencies, params)
|
||||||
|
|
||||||
|
path, err := bundleGenerator.Generate()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("generate debug bundle: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := os.Remove(path); err != nil {
|
||||||
|
log.Errorf("failed to remove debug bundle file: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
key, err := debug.UploadDebugBundle(ctx, types.DefaultBundleURL, mgmURL, path)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to upload debug bundle: %v", err)
|
||||||
|
return "", fmt.Errorf("upload debug bundle: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("debug bundle has been generated successfully")
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitFor(ctx context.Context, duration time.Duration) error {
|
||||||
|
log.Infof("wait for %v minutes before executing debug bundle", duration.Minutes())
|
||||||
|
select {
|
||||||
|
case <-time.After(duration):
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Infof("wait cancelled: %v", ctx.Err())
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// protoc-gen-go v1.36.6
|
// protoc-gen-go v1.36.6
|
||||||
// protoc v3.21.12
|
// protoc v6.32.1
|
||||||
// source: daemon.proto
|
// source: daemon.proto
|
||||||
|
|
||||||
package proto
|
package proto
|
||||||
@@ -2757,7 +2757,6 @@ func (x *ForwardingRulesResponse) GetRules() []*ForwardingRule {
|
|||||||
type DebugBundleRequest struct {
|
type DebugBundleRequest struct {
|
||||||
state protoimpl.MessageState `protogen:"open.v1"`
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
Anonymize bool `protobuf:"varint,1,opt,name=anonymize,proto3" json:"anonymize,omitempty"`
|
Anonymize bool `protobuf:"varint,1,opt,name=anonymize,proto3" json:"anonymize,omitempty"`
|
||||||
Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"`
|
|
||||||
SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"`
|
SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"`
|
||||||
UploadURL string `protobuf:"bytes,4,opt,name=uploadURL,proto3" json:"uploadURL,omitempty"`
|
UploadURL string `protobuf:"bytes,4,opt,name=uploadURL,proto3" json:"uploadURL,omitempty"`
|
||||||
LogFileCount uint32 `protobuf:"varint,5,opt,name=logFileCount,proto3" json:"logFileCount,omitempty"`
|
LogFileCount uint32 `protobuf:"varint,5,opt,name=logFileCount,proto3" json:"logFileCount,omitempty"`
|
||||||
@@ -2802,13 +2801,6 @@ func (x *DebugBundleRequest) GetAnonymize() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x *DebugBundleRequest) GetStatus() string {
|
|
||||||
if x != nil {
|
|
||||||
return x.Status
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (x *DebugBundleRequest) GetSystemInfo() bool {
|
func (x *DebugBundleRequest) GetSystemInfo() bool {
|
||||||
if x != nil {
|
if x != nil {
|
||||||
return x.SystemInfo
|
return x.SystemInfo
|
||||||
@@ -5372,6 +5364,154 @@ func (x *WaitJWTTokenResponse) GetExpiresIn() int64 {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StartCPUProfileRequest for starting CPU profiling
|
||||||
|
type StartCPUProfileRequest struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *StartCPUProfileRequest) Reset() {
|
||||||
|
*x = StartCPUProfileRequest{}
|
||||||
|
mi := &file_daemon_proto_msgTypes[79]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *StartCPUProfileRequest) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*StartCPUProfileRequest) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *StartCPUProfileRequest) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_daemon_proto_msgTypes[79]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use StartCPUProfileRequest.ProtoReflect.Descriptor instead.
|
||||||
|
func (*StartCPUProfileRequest) Descriptor() ([]byte, []int) {
|
||||||
|
return file_daemon_proto_rawDescGZIP(), []int{79}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartCPUProfileResponse confirms CPU profiling has started
|
||||||
|
type StartCPUProfileResponse struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *StartCPUProfileResponse) Reset() {
|
||||||
|
*x = StartCPUProfileResponse{}
|
||||||
|
mi := &file_daemon_proto_msgTypes[80]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *StartCPUProfileResponse) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*StartCPUProfileResponse) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *StartCPUProfileResponse) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_daemon_proto_msgTypes[80]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use StartCPUProfileResponse.ProtoReflect.Descriptor instead.
|
||||||
|
func (*StartCPUProfileResponse) Descriptor() ([]byte, []int) {
|
||||||
|
return file_daemon_proto_rawDescGZIP(), []int{80}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopCPUProfileRequest for stopping CPU profiling
|
||||||
|
type StopCPUProfileRequest struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *StopCPUProfileRequest) Reset() {
|
||||||
|
*x = StopCPUProfileRequest{}
|
||||||
|
mi := &file_daemon_proto_msgTypes[81]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *StopCPUProfileRequest) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*StopCPUProfileRequest) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *StopCPUProfileRequest) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_daemon_proto_msgTypes[81]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use StopCPUProfileRequest.ProtoReflect.Descriptor instead.
|
||||||
|
func (*StopCPUProfileRequest) Descriptor() ([]byte, []int) {
|
||||||
|
return file_daemon_proto_rawDescGZIP(), []int{81}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopCPUProfileResponse confirms CPU profiling has stopped
|
||||||
|
type StopCPUProfileResponse struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *StopCPUProfileResponse) Reset() {
|
||||||
|
*x = StopCPUProfileResponse{}
|
||||||
|
mi := &file_daemon_proto_msgTypes[82]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *StopCPUProfileResponse) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*StopCPUProfileResponse) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *StopCPUProfileResponse) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_daemon_proto_msgTypes[82]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use StopCPUProfileResponse.ProtoReflect.Descriptor instead.
|
||||||
|
func (*StopCPUProfileResponse) Descriptor() ([]byte, []int) {
|
||||||
|
return file_daemon_proto_rawDescGZIP(), []int{82}
|
||||||
|
}
|
||||||
|
|
||||||
type InstallerResultRequest struct {
|
type InstallerResultRequest struct {
|
||||||
state protoimpl.MessageState `protogen:"open.v1"`
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
unknownFields protoimpl.UnknownFields
|
unknownFields protoimpl.UnknownFields
|
||||||
@@ -5380,7 +5520,7 @@ type InstallerResultRequest struct {
|
|||||||
|
|
||||||
func (x *InstallerResultRequest) Reset() {
|
func (x *InstallerResultRequest) Reset() {
|
||||||
*x = InstallerResultRequest{}
|
*x = InstallerResultRequest{}
|
||||||
mi := &file_daemon_proto_msgTypes[79]
|
mi := &file_daemon_proto_msgTypes[83]
|
||||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
ms.StoreMessageInfo(mi)
|
ms.StoreMessageInfo(mi)
|
||||||
}
|
}
|
||||||
@@ -5392,7 +5532,7 @@ func (x *InstallerResultRequest) String() string {
|
|||||||
func (*InstallerResultRequest) ProtoMessage() {}
|
func (*InstallerResultRequest) ProtoMessage() {}
|
||||||
|
|
||||||
func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message {
|
func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message {
|
||||||
mi := &file_daemon_proto_msgTypes[79]
|
mi := &file_daemon_proto_msgTypes[83]
|
||||||
if x != nil {
|
if x != nil {
|
||||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
if ms.LoadMessageInfo() == nil {
|
if ms.LoadMessageInfo() == nil {
|
||||||
@@ -5405,7 +5545,7 @@ func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message {
|
|||||||
|
|
||||||
// Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead.
|
// Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead.
|
||||||
func (*InstallerResultRequest) Descriptor() ([]byte, []int) {
|
func (*InstallerResultRequest) Descriptor() ([]byte, []int) {
|
||||||
return file_daemon_proto_rawDescGZIP(), []int{79}
|
return file_daemon_proto_rawDescGZIP(), []int{83}
|
||||||
}
|
}
|
||||||
|
|
||||||
type InstallerResultResponse struct {
|
type InstallerResultResponse struct {
|
||||||
@@ -5418,7 +5558,7 @@ type InstallerResultResponse struct {
|
|||||||
|
|
||||||
func (x *InstallerResultResponse) Reset() {
|
func (x *InstallerResultResponse) Reset() {
|
||||||
*x = InstallerResultResponse{}
|
*x = InstallerResultResponse{}
|
||||||
mi := &file_daemon_proto_msgTypes[80]
|
mi := &file_daemon_proto_msgTypes[84]
|
||||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
ms.StoreMessageInfo(mi)
|
ms.StoreMessageInfo(mi)
|
||||||
}
|
}
|
||||||
@@ -5430,7 +5570,7 @@ func (x *InstallerResultResponse) String() string {
|
|||||||
func (*InstallerResultResponse) ProtoMessage() {}
|
func (*InstallerResultResponse) ProtoMessage() {}
|
||||||
|
|
||||||
func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message {
|
func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message {
|
||||||
mi := &file_daemon_proto_msgTypes[80]
|
mi := &file_daemon_proto_msgTypes[84]
|
||||||
if x != nil {
|
if x != nil {
|
||||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
if ms.LoadMessageInfo() == nil {
|
if ms.LoadMessageInfo() == nil {
|
||||||
@@ -5443,7 +5583,7 @@ func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message {
|
|||||||
|
|
||||||
// Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead.
|
// Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead.
|
||||||
func (*InstallerResultResponse) Descriptor() ([]byte, []int) {
|
func (*InstallerResultResponse) Descriptor() ([]byte, []int) {
|
||||||
return file_daemon_proto_rawDescGZIP(), []int{80}
|
return file_daemon_proto_rawDescGZIP(), []int{84}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x *InstallerResultResponse) GetSuccess() bool {
|
func (x *InstallerResultResponse) GetSuccess() bool {
|
||||||
@@ -5470,7 +5610,7 @@ type PortInfo_Range struct {
|
|||||||
|
|
||||||
func (x *PortInfo_Range) Reset() {
|
func (x *PortInfo_Range) Reset() {
|
||||||
*x = PortInfo_Range{}
|
*x = PortInfo_Range{}
|
||||||
mi := &file_daemon_proto_msgTypes[82]
|
mi := &file_daemon_proto_msgTypes[86]
|
||||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
ms.StoreMessageInfo(mi)
|
ms.StoreMessageInfo(mi)
|
||||||
}
|
}
|
||||||
@@ -5482,7 +5622,7 @@ func (x *PortInfo_Range) String() string {
|
|||||||
func (*PortInfo_Range) ProtoMessage() {}
|
func (*PortInfo_Range) ProtoMessage() {}
|
||||||
|
|
||||||
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
|
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
|
||||||
mi := &file_daemon_proto_msgTypes[82]
|
mi := &file_daemon_proto_msgTypes[86]
|
||||||
if x != nil {
|
if x != nil {
|
||||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
if ms.LoadMessageInfo() == nil {
|
if ms.LoadMessageInfo() == nil {
|
||||||
@@ -5773,10 +5913,9 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"\x12translatedHostname\x18\x04 \x01(\tR\x12translatedHostname\x128\n" +
|
"\x12translatedHostname\x18\x04 \x01(\tR\x12translatedHostname\x128\n" +
|
||||||
"\x0etranslatedPort\x18\x05 \x01(\v2\x10.daemon.PortInfoR\x0etranslatedPort\"G\n" +
|
"\x0etranslatedPort\x18\x05 \x01(\v2\x10.daemon.PortInfoR\x0etranslatedPort\"G\n" +
|
||||||
"\x17ForwardingRulesResponse\x12,\n" +
|
"\x17ForwardingRulesResponse\x12,\n" +
|
||||||
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\xac\x01\n" +
|
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\x94\x01\n" +
|
||||||
"\x12DebugBundleRequest\x12\x1c\n" +
|
"\x12DebugBundleRequest\x12\x1c\n" +
|
||||||
"\tanonymize\x18\x01 \x01(\bR\tanonymize\x12\x16\n" +
|
"\tanonymize\x18\x01 \x01(\bR\tanonymize\x12\x1e\n" +
|
||||||
"\x06status\x18\x02 \x01(\tR\x06status\x12\x1e\n" +
|
|
||||||
"\n" +
|
"\n" +
|
||||||
"systemInfo\x18\x03 \x01(\bR\n" +
|
"systemInfo\x18\x03 \x01(\bR\n" +
|
||||||
"systemInfo\x12\x1c\n" +
|
"systemInfo\x12\x1c\n" +
|
||||||
@@ -6003,6 +6142,10 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" +
|
"\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" +
|
||||||
"\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" +
|
"\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" +
|
||||||
"\texpiresIn\x18\x03 \x01(\x03R\texpiresIn\"\x18\n" +
|
"\texpiresIn\x18\x03 \x01(\x03R\texpiresIn\"\x18\n" +
|
||||||
|
"\x16StartCPUProfileRequest\"\x19\n" +
|
||||||
|
"\x17StartCPUProfileResponse\"\x17\n" +
|
||||||
|
"\x15StopCPUProfileRequest\"\x18\n" +
|
||||||
|
"\x16StopCPUProfileResponse\"\x18\n" +
|
||||||
"\x16InstallerResultRequest\"O\n" +
|
"\x16InstallerResultRequest\"O\n" +
|
||||||
"\x17InstallerResultResponse\x12\x18\n" +
|
"\x17InstallerResultResponse\x12\x18\n" +
|
||||||
"\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +
|
"\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +
|
||||||
@@ -6015,7 +6158,7 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"\x04WARN\x10\x04\x12\b\n" +
|
"\x04WARN\x10\x04\x12\b\n" +
|
||||||
"\x04INFO\x10\x05\x12\t\n" +
|
"\x04INFO\x10\x05\x12\t\n" +
|
||||||
"\x05DEBUG\x10\x06\x12\t\n" +
|
"\x05DEBUG\x10\x06\x12\t\n" +
|
||||||
"\x05TRACE\x10\a2\xb4\x13\n" +
|
"\x05TRACE\x10\a2\xdd\x14\n" +
|
||||||
"\rDaemonService\x126\n" +
|
"\rDaemonService\x126\n" +
|
||||||
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
|
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
|
||||||
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
|
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
|
||||||
@@ -6050,7 +6193,9 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" +
|
"\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" +
|
||||||
"\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" +
|
"\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" +
|
||||||
"\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" +
|
"\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" +
|
||||||
"\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12N\n" +
|
"\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12T\n" +
|
||||||
|
"\x0fStartCPUProfile\x12\x1e.daemon.StartCPUProfileRequest\x1a\x1f.daemon.StartCPUProfileResponse\"\x00\x12Q\n" +
|
||||||
|
"\x0eStopCPUProfile\x12\x1d.daemon.StopCPUProfileRequest\x1a\x1e.daemon.StopCPUProfileResponse\"\x00\x12N\n" +
|
||||||
"\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" +
|
"\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" +
|
||||||
"\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3"
|
"\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3"
|
||||||
|
|
||||||
@@ -6067,7 +6212,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4)
|
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4)
|
||||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 84)
|
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 88)
|
||||||
var file_daemon_proto_goTypes = []any{
|
var file_daemon_proto_goTypes = []any{
|
||||||
(LogLevel)(0), // 0: daemon.LogLevel
|
(LogLevel)(0), // 0: daemon.LogLevel
|
||||||
(OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType
|
(OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType
|
||||||
@@ -6152,21 +6297,25 @@ var file_daemon_proto_goTypes = []any{
|
|||||||
(*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse
|
(*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse
|
||||||
(*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest
|
(*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest
|
||||||
(*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse
|
(*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse
|
||||||
(*InstallerResultRequest)(nil), // 83: daemon.InstallerResultRequest
|
(*StartCPUProfileRequest)(nil), // 83: daemon.StartCPUProfileRequest
|
||||||
(*InstallerResultResponse)(nil), // 84: daemon.InstallerResultResponse
|
(*StartCPUProfileResponse)(nil), // 84: daemon.StartCPUProfileResponse
|
||||||
nil, // 85: daemon.Network.ResolvedIPsEntry
|
(*StopCPUProfileRequest)(nil), // 85: daemon.StopCPUProfileRequest
|
||||||
(*PortInfo_Range)(nil), // 86: daemon.PortInfo.Range
|
(*StopCPUProfileResponse)(nil), // 86: daemon.StopCPUProfileResponse
|
||||||
nil, // 87: daemon.SystemEvent.MetadataEntry
|
(*InstallerResultRequest)(nil), // 87: daemon.InstallerResultRequest
|
||||||
(*durationpb.Duration)(nil), // 88: google.protobuf.Duration
|
(*InstallerResultResponse)(nil), // 88: daemon.InstallerResultResponse
|
||||||
(*timestamppb.Timestamp)(nil), // 89: google.protobuf.Timestamp
|
nil, // 89: daemon.Network.ResolvedIPsEntry
|
||||||
|
(*PortInfo_Range)(nil), // 90: daemon.PortInfo.Range
|
||||||
|
nil, // 91: daemon.SystemEvent.MetadataEntry
|
||||||
|
(*durationpb.Duration)(nil), // 92: google.protobuf.Duration
|
||||||
|
(*timestamppb.Timestamp)(nil), // 93: google.protobuf.Timestamp
|
||||||
}
|
}
|
||||||
var file_daemon_proto_depIdxs = []int32{
|
var file_daemon_proto_depIdxs = []int32{
|
||||||
1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
|
1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
|
||||||
88, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
92, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||||
27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
||||||
89, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
93, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||||
89, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
93, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||||
88, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
92, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||||
25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
|
25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
|
||||||
22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
||||||
21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
||||||
@@ -6177,8 +6326,8 @@ var file_daemon_proto_depIdxs = []int32{
|
|||||||
57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
|
57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
|
||||||
26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
|
26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
|
||||||
33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
|
33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
|
||||||
85, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
89, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
||||||
86, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
|
90, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
|
||||||
34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
|
34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
|
||||||
34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
|
34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
|
||||||
35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
|
35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
|
||||||
@@ -6189,10 +6338,10 @@ var file_daemon_proto_depIdxs = []int32{
|
|||||||
54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
|
54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
|
||||||
2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
|
2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
|
||||||
3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
|
3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
|
||||||
89, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
|
93, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||||
87, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
|
91, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
|
||||||
57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
|
57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
|
||||||
88, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
92, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||||
70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
|
70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
|
||||||
32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
|
32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
|
||||||
7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
||||||
@@ -6226,43 +6375,47 @@ var file_daemon_proto_depIdxs = []int32{
|
|||||||
77, // 62: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
|
77, // 62: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
|
||||||
79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
|
79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
|
||||||
81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
|
81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
|
||||||
5, // 65: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
|
83, // 65: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
|
||||||
83, // 66: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
|
85, // 66: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
|
||||||
8, // 67: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
5, // 67: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
|
||||||
10, // 68: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
87, // 68: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
|
||||||
12, // 69: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
8, // 69: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||||
14, // 70: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
10, // 70: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||||
16, // 71: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
12, // 71: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||||
18, // 72: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
14, // 72: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
||||||
29, // 73: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
|
16, // 73: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
||||||
31, // 74: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
|
18, // 74: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
||||||
31, // 75: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
|
29, // 75: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
|
||||||
36, // 76: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
|
31, // 76: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||||
38, // 77: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
31, // 77: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||||
40, // 78: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
36, // 78: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
|
||||||
42, // 79: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
38, // 79: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
||||||
45, // 80: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
|
40, // 80: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
||||||
47, // 81: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
|
42, // 81: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
||||||
49, // 82: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
45, // 82: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
|
||||||
51, // 83: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
|
47, // 83: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
|
||||||
55, // 84: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
|
49, // 84: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
||||||
57, // 85: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
|
51, // 85: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
|
||||||
59, // 86: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
|
55, // 86: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
|
||||||
61, // 87: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
|
57, // 87: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
|
||||||
63, // 88: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
|
59, // 88: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
|
||||||
65, // 89: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
|
61, // 89: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
|
||||||
67, // 90: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
|
63, // 90: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
|
||||||
69, // 91: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
|
65, // 91: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
|
||||||
72, // 92: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
|
67, // 92: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
|
||||||
74, // 93: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
|
69, // 93: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
|
||||||
76, // 94: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
|
72, // 94: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
|
||||||
78, // 95: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
|
74, // 95: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
|
||||||
80, // 96: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
|
76, // 96: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
|
||||||
82, // 97: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
|
78, // 97: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
|
||||||
6, // 98: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
|
80, // 98: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
|
||||||
84, // 99: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
|
82, // 99: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
|
||||||
67, // [67:100] is the sub-list for method output_type
|
84, // 100: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
|
||||||
34, // [34:67] is the sub-list for method input_type
|
86, // 101: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
|
||||||
|
6, // 102: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
|
||||||
|
88, // 103: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
|
||||||
|
69, // [69:104] is the sub-list for method output_type
|
||||||
|
34, // [34:69] is the sub-list for method input_type
|
||||||
34, // [34:34] is the sub-list for extension type_name
|
34, // [34:34] is the sub-list for extension type_name
|
||||||
34, // [34:34] is the sub-list for extension extendee
|
34, // [34:34] is the sub-list for extension extendee
|
||||||
0, // [0:34] is the sub-list for field type_name
|
0, // [0:34] is the sub-list for field type_name
|
||||||
@@ -6292,7 +6445,7 @@ func file_daemon_proto_init() {
|
|||||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
|
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
|
||||||
NumEnums: 4,
|
NumEnums: 4,
|
||||||
NumMessages: 84,
|
NumMessages: 88,
|
||||||
NumExtensions: 0,
|
NumExtensions: 0,
|
||||||
NumServices: 1,
|
NumServices: 1,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -94,6 +94,12 @@ service DaemonService {
|
|||||||
// WaitJWTToken waits for JWT authentication completion
|
// WaitJWTToken waits for JWT authentication completion
|
||||||
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
|
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
|
||||||
|
|
||||||
|
// StartCPUProfile starts CPU profiling in the daemon
|
||||||
|
rpc StartCPUProfile(StartCPUProfileRequest) returns (StartCPUProfileResponse) {}
|
||||||
|
|
||||||
|
// StopCPUProfile stops CPU profiling in the daemon
|
||||||
|
rpc StopCPUProfile(StopCPUProfileRequest) returns (StopCPUProfileResponse) {}
|
||||||
|
|
||||||
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
|
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
|
||||||
|
|
||||||
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
|
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
|
||||||
@@ -455,7 +461,6 @@ message ForwardingRulesResponse {
|
|||||||
// DebugBundler
|
// DebugBundler
|
||||||
message DebugBundleRequest {
|
message DebugBundleRequest {
|
||||||
bool anonymize = 1;
|
bool anonymize = 1;
|
||||||
string status = 2;
|
|
||||||
bool systemInfo = 3;
|
bool systemInfo = 3;
|
||||||
string uploadURL = 4;
|
string uploadURL = 4;
|
||||||
uint32 logFileCount = 5;
|
uint32 logFileCount = 5;
|
||||||
@@ -777,6 +782,18 @@ message WaitJWTTokenResponse {
|
|||||||
int64 expiresIn = 3;
|
int64 expiresIn = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StartCPUProfileRequest for starting CPU profiling
|
||||||
|
message StartCPUProfileRequest {}
|
||||||
|
|
||||||
|
// StartCPUProfileResponse confirms CPU profiling has started
|
||||||
|
message StartCPUProfileResponse {}
|
||||||
|
|
||||||
|
// StopCPUProfileRequest for stopping CPU profiling
|
||||||
|
message StopCPUProfileRequest {}
|
||||||
|
|
||||||
|
// StopCPUProfileResponse confirms CPU profiling has stopped
|
||||||
|
message StopCPUProfileResponse {}
|
||||||
|
|
||||||
message InstallerResultRequest {
|
message InstallerResultRequest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -70,6 +70,10 @@ type DaemonServiceClient interface {
|
|||||||
RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error)
|
RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error)
|
||||||
// WaitJWTToken waits for JWT authentication completion
|
// WaitJWTToken waits for JWT authentication completion
|
||||||
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
|
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
|
||||||
|
// StartCPUProfile starts CPU profiling in the daemon
|
||||||
|
StartCPUProfile(ctx context.Context, in *StartCPUProfileRequest, opts ...grpc.CallOption) (*StartCPUProfileResponse, error)
|
||||||
|
// StopCPUProfile stops CPU profiling in the daemon
|
||||||
|
StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error)
|
||||||
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
|
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
|
||||||
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
|
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
|
||||||
}
|
}
|
||||||
@@ -384,6 +388,24 @@ func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTToken
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *daemonServiceClient) StartCPUProfile(ctx context.Context, in *StartCPUProfileRequest, opts ...grpc.CallOption) (*StartCPUProfileResponse, error) {
|
||||||
|
out := new(StartCPUProfileResponse)
|
||||||
|
err := c.cc.Invoke(ctx, "/daemon.DaemonService/StartCPUProfile", in, out, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *daemonServiceClient) StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error) {
|
||||||
|
out := new(StopCPUProfileResponse)
|
||||||
|
err := c.cc.Invoke(ctx, "/daemon.DaemonService/StopCPUProfile", in, out, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) {
|
func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) {
|
||||||
out := new(OSLifecycleResponse)
|
out := new(OSLifecycleResponse)
|
||||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/NotifyOSLifecycle", in, out, opts...)
|
err := c.cc.Invoke(ctx, "/daemon.DaemonService/NotifyOSLifecycle", in, out, opts...)
|
||||||
@@ -458,6 +480,10 @@ type DaemonServiceServer interface {
|
|||||||
RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error)
|
RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error)
|
||||||
// WaitJWTToken waits for JWT authentication completion
|
// WaitJWTToken waits for JWT authentication completion
|
||||||
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
|
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
|
||||||
|
// StartCPUProfile starts CPU profiling in the daemon
|
||||||
|
StartCPUProfile(context.Context, *StartCPUProfileRequest) (*StartCPUProfileResponse, error)
|
||||||
|
// StopCPUProfile stops CPU profiling in the daemon
|
||||||
|
StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error)
|
||||||
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
|
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
|
||||||
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
|
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
|
||||||
mustEmbedUnimplementedDaemonServiceServer()
|
mustEmbedUnimplementedDaemonServiceServer()
|
||||||
@@ -560,6 +586,12 @@ func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *Request
|
|||||||
func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) {
|
func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) {
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented")
|
||||||
}
|
}
|
||||||
|
func (UnimplementedDaemonServiceServer) StartCPUProfile(context.Context, *StartCPUProfileRequest) (*StartCPUProfileResponse, error) {
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method StartCPUProfile not implemented")
|
||||||
|
}
|
||||||
|
func (UnimplementedDaemonServiceServer) StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error) {
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method StopCPUProfile not implemented")
|
||||||
|
}
|
||||||
func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) {
|
func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) {
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented")
|
||||||
}
|
}
|
||||||
@@ -1140,6 +1172,42 @@ func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, d
|
|||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func _DaemonService_StartCPUProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(StartCPUProfileRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(DaemonServiceServer).StartCPUProfile(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: "/daemon.DaemonService/StartCPUProfile",
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(DaemonServiceServer).StartCPUProfile(ctx, req.(*StartCPUProfileRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _DaemonService_StopCPUProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(StopCPUProfileRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(DaemonServiceServer).StopCPUProfile(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: "/daemon.DaemonService/StopCPUProfile",
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(DaemonServiceServer).StopCPUProfile(ctx, req.(*StopCPUProfileRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
in := new(OSLifecycleRequest)
|
in := new(OSLifecycleRequest)
|
||||||
if err := dec(in); err != nil {
|
if err := dec(in); err != nil {
|
||||||
@@ -1303,6 +1371,14 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
|||||||
MethodName: "WaitJWTToken",
|
MethodName: "WaitJWTToken",
|
||||||
Handler: _DaemonService_WaitJWTToken_Handler,
|
Handler: _DaemonService_WaitJWTToken_Handler,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
MethodName: "StartCPUProfile",
|
||||||
|
Handler: _DaemonService_StartCPUProfile_Handler,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
MethodName: "StopCPUProfile",
|
||||||
|
Handler: _DaemonService_StopCPUProfile_Handler,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
MethodName: "NotifyOSLifecycle",
|
MethodName: "NotifyOSLifecycle",
|
||||||
Handler: _DaemonService_NotifyOSLifecycle_Handler,
|
Handler: _DaemonService_NotifyOSLifecycle_Handler,
|
||||||
|
|||||||
@@ -3,25 +3,19 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"runtime/pprof"
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/debug"
|
"github.com/netbirdio/netbird/client/internal/debug"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/upload-server/types"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const maxBundleUploadSize = 50 * 1024 * 1024
|
|
||||||
|
|
||||||
// DebugBundle creates a debug bundle and returns the location.
|
// DebugBundle creates a debug bundle and returns the location.
|
||||||
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
|
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
|
||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
@@ -32,16 +26,37 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
|||||||
log.Warnf("failed to get latest sync response: %v", err)
|
log.Warnf("failed to get latest sync response: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var cpuProfileData []byte
|
||||||
|
if s.cpuProfileBuf != nil && !s.cpuProfiling {
|
||||||
|
cpuProfileData = s.cpuProfileBuf.Bytes()
|
||||||
|
defer func() {
|
||||||
|
s.cpuProfileBuf = nil
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare refresh callback for health probes
|
||||||
|
var refreshStatus func()
|
||||||
|
if s.connectClient != nil {
|
||||||
|
engine := s.connectClient.Engine()
|
||||||
|
if engine != nil {
|
||||||
|
refreshStatus = func() {
|
||||||
|
log.Debug("refreshing system health status for debug bundle")
|
||||||
|
engine.RunHealthProbes(true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bundleGenerator := debug.NewBundleGenerator(
|
bundleGenerator := debug.NewBundleGenerator(
|
||||||
debug.GeneratorDependencies{
|
debug.GeneratorDependencies{
|
||||||
InternalConfig: s.config,
|
InternalConfig: s.config,
|
||||||
StatusRecorder: s.statusRecorder,
|
StatusRecorder: s.statusRecorder,
|
||||||
SyncResponse: syncResponse,
|
SyncResponse: syncResponse,
|
||||||
LogFile: s.logFile,
|
LogPath: s.logFile,
|
||||||
|
CPUProfile: cpuProfileData,
|
||||||
|
RefreshStatus: refreshStatus,
|
||||||
},
|
},
|
||||||
debug.BundleConfig{
|
debug.BundleConfig{
|
||||||
Anonymize: req.GetAnonymize(),
|
Anonymize: req.GetAnonymize(),
|
||||||
ClientStatus: req.GetStatus(),
|
|
||||||
IncludeSystemInfo: req.GetSystemInfo(),
|
IncludeSystemInfo: req.GetSystemInfo(),
|
||||||
LogFileCount: req.GetLogFileCount(),
|
LogFileCount: req.GetLogFileCount(),
|
||||||
},
|
},
|
||||||
@@ -55,7 +70,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
|||||||
if req.GetUploadURL() == "" {
|
if req.GetUploadURL() == "" {
|
||||||
return &proto.DebugBundleResponse{Path: path}, nil
|
return &proto.DebugBundleResponse{Path: path}, nil
|
||||||
}
|
}
|
||||||
key, err := uploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path)
|
key, err := debug.UploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to upload debug bundle to %s: %v", req.GetUploadURL(), err)
|
log.Errorf("failed to upload debug bundle to %s: %v", req.GetUploadURL(), err)
|
||||||
return &proto.DebugBundleResponse{Path: path, UploadFailureReason: err.Error()}, nil
|
return &proto.DebugBundleResponse{Path: path, UploadFailureReason: err.Error()}, nil
|
||||||
@@ -66,92 +81,6 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
|||||||
return &proto.DebugBundleResponse{Path: path, UploadedKey: key}, nil
|
return &proto.DebugBundleResponse{Path: path, UploadedKey: key}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func uploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) {
|
|
||||||
response, err := getUploadURL(ctx, url, managementURL)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = upload(ctx, filePath, response)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return response.Key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error {
|
|
||||||
fileData, err := os.Open(filePath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("open file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer fileData.Close()
|
|
||||||
|
|
||||||
stat, err := fileData.Stat()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("stat file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if stat.Size() > maxBundleUploadSize {
|
|
||||||
return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("create PUT request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req.ContentLength = stat.Size()
|
|
||||||
req.Header.Set("Content-Type", "application/octet-stream")
|
|
||||||
|
|
||||||
putResp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("upload failed: %v", err)
|
|
||||||
}
|
|
||||||
defer putResp.Body.Close()
|
|
||||||
|
|
||||||
if putResp.StatusCode != http.StatusOK {
|
|
||||||
body, _ := io.ReadAll(putResp.Body)
|
|
||||||
return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body))
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) {
|
|
||||||
id := getURLHash(managementURL)
|
|
||||||
getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create GET request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue)
|
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(getReq)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get presigned URL: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
body, _ := io.ReadAll(resp.Body)
|
|
||||||
return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
urlBytes, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("read response body: %w", err)
|
|
||||||
}
|
|
||||||
var response types.GetURLResponse
|
|
||||||
if err := json.Unmarshal(urlBytes, &response); err != nil {
|
|
||||||
return nil, fmt.Errorf("unmarshal response: %w", err)
|
|
||||||
}
|
|
||||||
return &response, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getURLHash(url string) string {
|
|
||||||
return fmt.Sprintf("%x", sha256.Sum256([]byte(url)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLogLevel gets the current logging level for the server.
|
// GetLogLevel gets the current logging level for the server.
|
||||||
func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) {
|
func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) {
|
||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
@@ -204,3 +133,43 @@ func (s *Server) getLatestSyncResponse() (*mgmProto.SyncResponse, error) {
|
|||||||
|
|
||||||
return cClient.GetLatestSyncResponse()
|
return cClient.GetLatestSyncResponse()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StartCPUProfile starts CPU profiling in the daemon.
|
||||||
|
func (s *Server) StartCPUProfile(_ context.Context, _ *proto.StartCPUProfileRequest) (*proto.StartCPUProfileResponse, error) {
|
||||||
|
s.mutex.Lock()
|
||||||
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
|
if s.cpuProfiling {
|
||||||
|
return nil, fmt.Errorf("CPU profiling already in progress")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.cpuProfileBuf = &bytes.Buffer{}
|
||||||
|
s.cpuProfiling = true
|
||||||
|
if err := pprof.StartCPUProfile(s.cpuProfileBuf); err != nil {
|
||||||
|
s.cpuProfileBuf = nil
|
||||||
|
s.cpuProfiling = false
|
||||||
|
return nil, fmt.Errorf("start CPU profile: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("CPU profiling started")
|
||||||
|
return &proto.StartCPUProfileResponse{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopCPUProfile stops CPU profiling in the daemon.
|
||||||
|
func (s *Server) StopCPUProfile(_ context.Context, _ *proto.StopCPUProfileRequest) (*proto.StopCPUProfileResponse, error) {
|
||||||
|
s.mutex.Lock()
|
||||||
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
|
if !s.cpuProfiling {
|
||||||
|
return nil, fmt.Errorf("CPU profiling not in progress")
|
||||||
|
}
|
||||||
|
|
||||||
|
pprof.StopCPUProfile()
|
||||||
|
s.cpuProfiling = false
|
||||||
|
|
||||||
|
if s.cpuProfileBuf != nil {
|
||||||
|
log.Infof("CPU profiling stopped, captured %d bytes", s.cpuProfileBuf.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
return &proto.StopCPUProfileResponse{}, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -13,9 +14,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
@@ -67,7 +67,7 @@ type Server struct {
|
|||||||
proto.UnimplementedDaemonServiceServer
|
proto.UnimplementedDaemonServiceServer
|
||||||
clientRunning bool // protected by mutex
|
clientRunning bool // protected by mutex
|
||||||
clientRunningChan chan struct{}
|
clientRunningChan chan struct{}
|
||||||
clientGiveUpChan chan struct{}
|
clientGiveUpChan chan struct{} // closed when connectWithRetryRuns goroutine exits
|
||||||
|
|
||||||
connectClient *internal.ConnectClient
|
connectClient *internal.ConnectClient
|
||||||
|
|
||||||
@@ -78,6 +78,9 @@ type Server struct {
|
|||||||
persistSyncResponse bool
|
persistSyncResponse bool
|
||||||
isSessionActive atomic.Bool
|
isSessionActive atomic.Bool
|
||||||
|
|
||||||
|
cpuProfileBuf *bytes.Buffer
|
||||||
|
cpuProfiling bool
|
||||||
|
|
||||||
profileManager *profilemanager.ServiceManager
|
profileManager *profilemanager.ServiceManager
|
||||||
profilesDisabled bool
|
profilesDisabled bool
|
||||||
updateSettingsDisabled bool
|
updateSettingsDisabled bool
|
||||||
@@ -250,10 +253,17 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
|
|||||||
|
|
||||||
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
|
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
|
||||||
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
|
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
|
||||||
var status internal.StatusType
|
authClient, err := auth.NewAuth(ctx, s.config.PrivateKey, s.config.ManagementURL, s.config)
|
||||||
err := internal.Login(ctx, s.config, setupKey, jwtToken)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
log.Errorf("failed to create auth client: %v", err)
|
||||||
|
return internal.StatusLoginFailed, err
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
var status internal.StatusType
|
||||||
|
err, isAuthError := authClient.Login(ctx, setupKey, jwtToken)
|
||||||
|
if err != nil {
|
||||||
|
if isAuthError {
|
||||||
log.Warnf("failed login: %v", err)
|
log.Warnf("failed login: %v", err)
|
||||||
status = internal.StatusNeedsLogin
|
status = internal.StatusNeedsLogin
|
||||||
} else {
|
} else {
|
||||||
@@ -578,8 +588,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
|||||||
s.oauthAuthFlow.waitCancel()
|
s.oauthAuthFlow.waitCancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
waitTimeout := time.Until(s.oauthAuthFlow.expiresAt)
|
waitCTX, cancel := context.WithCancel(ctx)
|
||||||
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
|
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
@@ -793,9 +802,11 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
|
|||||||
// Down engine work in the daemon.
|
// Down engine work in the daemon.
|
||||||
func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
|
func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
|
||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
giveUpChan := s.clientGiveUpChan
|
||||||
|
|
||||||
if err := s.cleanupConnection(); err != nil {
|
if err := s.cleanupConnection(); err != nil {
|
||||||
|
s.mutex.Unlock()
|
||||||
// todo review to update the status in case any type of error
|
// todo review to update the status in case any type of error
|
||||||
log.Errorf("failed to shut down properly: %v", err)
|
log.Errorf("failed to shut down properly: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -804,6 +815,20 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
|
|||||||
state := internal.CtxGetState(s.rootCtx)
|
state := internal.CtxGetState(s.rootCtx)
|
||||||
state.Set(internal.StatusIdle)
|
state.Set(internal.StatusIdle)
|
||||||
|
|
||||||
|
s.mutex.Unlock()
|
||||||
|
|
||||||
|
// Wait for the connectWithRetryRuns goroutine to finish with a short timeout.
|
||||||
|
// This prevents the goroutine from setting ErrResetConnection after Down() returns.
|
||||||
|
// The giveUpChan is closed at the end of connectWithRetryRuns.
|
||||||
|
if giveUpChan != nil {
|
||||||
|
select {
|
||||||
|
case <-giveUpChan:
|
||||||
|
log.Debugf("client goroutine finished successfully")
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
log.Warnf("timeout waiting for client goroutine to finish, proceeding anyway")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &proto.DownResponse{}, nil
|
return &proto.DownResponse{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1308,6 +1333,10 @@ func (s *Server) runProbes(waitForProbeResult bool) {
|
|||||||
if engine.RunHealthProbes(waitForProbeResult) {
|
if engine.RunHealthProbes(waitForProbeResult) {
|
||||||
s.lastProbe = time.Now()
|
s.lastProbe = time.Now()
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
if err := s.statusRecorder.RefreshWireGuardStats(); err != nil {
|
||||||
|
log.Debugf("failed to refresh WireGuard stats: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1521,7 +1550,7 @@ func (s *Server) connect(ctx context.Context, config *profilemanager.Config, sta
|
|||||||
log.Tracef("running client connection")
|
log.Tracef("running client connection")
|
||||||
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate)
|
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate)
|
||||||
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
|
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
|
||||||
if err := s.connectClient.Run(runningChan); err != nil {
|
if err := s.connectClient.Run(runningChan, s.logFile); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
"github.com/netbirdio/netbird/management/server/job"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
@@ -306,6 +307,8 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
|||||||
peersManager := peers.NewManager(store, permissionsManagerMock)
|
peersManager := peers.NewManager(store, permissionsManagerMock)
|
||||||
settingsManagerMock := settings.NewMockManager(ctrl)
|
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||||
|
|
||||||
|
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||||
|
|
||||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore)
|
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore)
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
@@ -317,7 +320,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
|||||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||||
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
@@ -326,7 +329,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -207,8 +207,6 @@ func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) {
|
func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) {
|
||||||
// Create a backend session to mirror the client's session request.
|
|
||||||
// This keeps the connection alive on the server side while port forwarding channels operate.
|
|
||||||
serverSession, err := sshClient.NewSession()
|
serverSession, err := sshClient.NewSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
|
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
|
||||||
@@ -216,10 +214,28 @@ func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *c
|
|||||||
}
|
}
|
||||||
defer func() { _ = serverSession.Close() }()
|
defer func() { _ = serverSession.Close() }()
|
||||||
|
|
||||||
<-session.Context().Done()
|
serverSession.Stdin = session
|
||||||
|
serverSession.Stdout = session
|
||||||
|
serverSession.Stderr = session.Stderr()
|
||||||
|
|
||||||
if err := session.Exit(0); err != nil {
|
if err := serverSession.Shell(); err != nil {
|
||||||
log.Debugf("session exit: %v", err)
|
log.Debugf("start shell: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- serverSession.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-session.Context().Done():
|
||||||
|
return
|
||||||
|
case err := <-done:
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("shell session: %v", err)
|
||||||
|
p.handleProxyExitCode(session, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// handleCommand executes an SSH command with privilege validation
|
// handleExecution executes an SSH command or shell with privilege validation
|
||||||
func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) {
|
func (s *Server) handleExecution(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) {
|
||||||
hasPty := winCh != nil
|
hasPty := winCh != nil
|
||||||
|
|
||||||
commandType := "command"
|
commandType := "command"
|
||||||
@@ -23,7 +23,7 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
|
|||||||
|
|
||||||
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
|
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
|
||||||
|
|
||||||
execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty)
|
execCmd, cleanup, err := s.createCommand(logger, privilegeResult, session, hasPty)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("%s creation failed: %v", commandType, err)
|
logger.Errorf("%s creation failed: %v", commandType, err)
|
||||||
|
|
||||||
@@ -51,13 +51,12 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
|
|||||||
|
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
ptyReq, _, _ := session.Pty()
|
|
||||||
if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
|
if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
|
||||||
logger.Debugf("%s execution completed", commandType)
|
logger.Debugf("%s execution completed", commandType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
|
func (s *Server) createCommand(logger *log.Entry, privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
|
||||||
localUser := privilegeResult.User
|
localUser := privilegeResult.User
|
||||||
if localUser == nil {
|
if localUser == nil {
|
||||||
return nil, nil, errors.New("no user in privilege result")
|
return nil, nil, errors.New("no user in privilege result")
|
||||||
@@ -66,28 +65,28 @@ func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh
|
|||||||
// If PTY requested but su doesn't support --pty, skip su and use executor
|
// If PTY requested but su doesn't support --pty, skip su and use executor
|
||||||
// This ensures PTY functionality is provided (executor runs within our allocated PTY)
|
// This ensures PTY functionality is provided (executor runs within our allocated PTY)
|
||||||
if hasPty && !s.suSupportsPty {
|
if hasPty && !s.suSupportsPty {
|
||||||
log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
|
logger.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
|
||||||
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
|
cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
||||||
}
|
}
|
||||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
|
||||||
return cmd, cleanup, nil
|
return cmd, cleanup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try su first for system integration (PAM/audit) when privileged
|
// Try su first for system integration (PAM/audit) when privileged
|
||||||
cmd, err := s.createSuCommand(session, localUser, hasPty)
|
cmd, err := s.createSuCommand(logger, session, localUser, hasPty)
|
||||||
if err != nil || privilegeResult.UsedFallback {
|
if err != nil || privilegeResult.UsedFallback {
|
||||||
log.Debugf("su command failed, falling back to executor: %v", err)
|
logger.Debugf("su command failed, falling back to executor: %v", err)
|
||||||
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
|
cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
||||||
}
|
}
|
||||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
|
||||||
return cmd, cleanup, nil
|
return cmd, cleanup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
|
||||||
return cmd, func() {}, nil
|
return cmd, func() {}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,17 +15,17 @@ import (
|
|||||||
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
|
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
|
||||||
|
|
||||||
// createSuCommand is not supported on JS/WASM
|
// createSuCommand is not supported on JS/WASM
|
||||||
func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
|
func (s *Server) createSuCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
|
||||||
return nil, errNotSupported
|
return nil, errNotSupported
|
||||||
}
|
}
|
||||||
|
|
||||||
// createExecutorCommand is not supported on JS/WASM
|
// createExecutorCommand is not supported on JS/WASM
|
||||||
func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
|
func (s *Server) createExecutorCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
|
||||||
return nil, nil, errNotSupported
|
return nil, nil, errNotSupported
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepareCommandEnv is not supported on JS/WASM
|
// prepareCommandEnv is not supported on JS/WASM
|
||||||
func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string {
|
func (s *Server) prepareCommandEnv(_ *log.Entry, _ *user.User, _ ssh.Session) []string {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"os/user"
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -99,40 +100,52 @@ func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool {
|
|||||||
return isUtilLinux
|
return isUtilLinux
|
||||||
}
|
}
|
||||||
|
|
||||||
// createSuCommand creates a command using su -l -c for privilege switching
|
// createSuCommand creates a command using su - for privilege switching.
|
||||||
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
func (s *Server) createSuCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
||||||
|
if err := validateUsername(localUser.Username); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
|
||||||
|
}
|
||||||
|
|
||||||
suPath, err := exec.LookPath("su")
|
suPath, err := exec.LookPath("su")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("su command not available: %w", err)
|
return nil, fmt.Errorf("su command not available: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
command := session.RawCommand()
|
args := []string{"-"}
|
||||||
if command == "" {
|
|
||||||
return nil, fmt.Errorf("no command specified for su execution")
|
|
||||||
}
|
|
||||||
|
|
||||||
args := []string{"-l"}
|
|
||||||
if hasPty && s.suSupportsPty {
|
if hasPty && s.suSupportsPty {
|
||||||
args = append(args, "--pty")
|
args = append(args, "--pty")
|
||||||
}
|
}
|
||||||
args = append(args, localUser.Username, "-c", command)
|
args = append(args, localUser.Username)
|
||||||
|
|
||||||
|
command := session.RawCommand()
|
||||||
|
if command != "" {
|
||||||
|
args = append(args, "-c", command)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debugf("creating su command: %s %v", suPath, args)
|
||||||
cmd := exec.CommandContext(session.Context(), suPath, args...)
|
cmd := exec.CommandContext(session.Context(), suPath, args...)
|
||||||
cmd.Dir = localUser.HomeDir
|
cmd.Dir = localUser.HomeDir
|
||||||
|
|
||||||
return cmd, nil
|
return cmd, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getShellCommandArgs returns the shell command and arguments for executing a command string
|
// getShellCommandArgs returns the shell command and arguments for executing a command string.
|
||||||
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
||||||
if cmdString == "" {
|
if cmdString == "" {
|
||||||
return []string{shell, "-l"}
|
return []string{shell}
|
||||||
}
|
}
|
||||||
return []string{shell, "-l", "-c", cmdString}
|
return []string{shell, "-c", cmdString}
|
||||||
|
}
|
||||||
|
|
||||||
|
// createShellCommand creates an exec.Cmd configured as a login shell by setting argv[0] to "-shellname".
|
||||||
|
func (s *Server) createShellCommand(ctx context.Context, shell string, args []string) *exec.Cmd {
|
||||||
|
cmd := exec.CommandContext(ctx, shell, args[1:]...)
|
||||||
|
cmd.Args[0] = "-" + filepath.Base(shell)
|
||||||
|
return cmd
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepareCommandEnv prepares environment variables for command execution on Unix
|
// prepareCommandEnv prepares environment variables for command execution on Unix
|
||||||
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
func (s *Server) prepareCommandEnv(_ *log.Entry, localUser *user.User, session ssh.Session) []string {
|
||||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||||
env = append(env, prepareSSHEnv(session)...)
|
env = append(env, prepareSSHEnv(session)...)
|
||||||
for _, v := range session.Environ() {
|
for _, v := range session.Environ() {
|
||||||
@@ -154,7 +167,7 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, e
|
|||||||
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
|
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||||
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
|
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Pty command creation failed: %v", err)
|
logger.Errorf("Pty command creation failed: %v", err)
|
||||||
@@ -244,11 +257,6 @@ func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *pty
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer func() {
|
|
||||||
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
|
|
||||||
logger.Debugf("session close error: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
if _, err := io.Copy(session, ptmx); err != nil {
|
if _, err := io.Copy(session, ptmx); err != nil {
|
||||||
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
|
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
|
||||||
logger.Warnf("Pty output copy error: %v", err)
|
logger.Warnf("Pty output copy error: %v", err)
|
||||||
@@ -268,7 +276,7 @@ func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, ex
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done)
|
s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done)
|
||||||
case err := <-done:
|
case err := <-done:
|
||||||
s.handlePtyCommandCompletion(logger, session, err)
|
s.handlePtyCommandCompletion(logger, session, ptyMgr, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -296,17 +304,20 @@ func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Ses
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) {
|
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Debugf("Pty command execution failed: %v", err)
|
logger.Debugf("Pty command execution failed: %v", err)
|
||||||
s.handleSessionExit(session, err, logger)
|
s.handleSessionExit(session, err, logger)
|
||||||
return
|
} else {
|
||||||
|
logger.Debugf("Pty command completed successfully")
|
||||||
|
if err := session.Exit(0); err != nil {
|
||||||
|
logSessionExitError(logger, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Normal completion
|
// Close PTY to unblock io.Copy goroutines
|
||||||
logger.Debugf("Pty command completed successfully")
|
if err := ptyMgr.Close(); err != nil {
|
||||||
if err := session.Exit(0); err != nil {
|
logger.Debugf("Pty close after completion: %v", err)
|
||||||
logSessionExitError(logger, err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -20,32 +20,32 @@ import (
|
|||||||
|
|
||||||
// getUserEnvironment retrieves the Windows environment for the target user.
|
// getUserEnvironment retrieves the Windows environment for the target user.
|
||||||
// Follows OpenSSH's resilient approach with graceful degradation on failures.
|
// Follows OpenSSH's resilient approach with graceful degradation on failures.
|
||||||
func (s *Server) getUserEnvironment(username, domain string) ([]string, error) {
|
func (s *Server) getUserEnvironment(logger *log.Entry, username, domain string) ([]string, error) {
|
||||||
userToken, err := s.getUserToken(username, domain)
|
userToken, err := s.getUserToken(logger, username, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get user token: %w", err)
|
return nil, fmt.Errorf("get user token: %w", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := windows.CloseHandle(userToken); err != nil {
|
if err := windows.CloseHandle(userToken); err != nil {
|
||||||
log.Debugf("close user token: %v", err)
|
logger.Debugf("close user token: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return s.getUserEnvironmentWithToken(userToken, username, domain)
|
return s.getUserEnvironmentWithToken(logger, userToken, username, domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
|
// getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
|
||||||
func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) {
|
func (s *Server) getUserEnvironmentWithToken(logger *log.Entry, userToken windows.Handle, username, domain string) ([]string, error) {
|
||||||
userProfile, err := s.loadUserProfile(userToken, username, domain)
|
userProfile, err := s.loadUserProfile(userToken, username, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
|
logger.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
|
||||||
userProfile = fmt.Sprintf("C:\\Users\\%s", username)
|
userProfile = fmt.Sprintf("C:\\Users\\%s", username)
|
||||||
}
|
}
|
||||||
|
|
||||||
envMap := make(map[string]string)
|
envMap := make(map[string]string)
|
||||||
|
|
||||||
if err := s.loadSystemEnvironment(envMap); err != nil {
|
if err := s.loadSystemEnvironment(envMap); err != nil {
|
||||||
log.Debugf("failed to load system environment from registry: %v", err)
|
logger.Debugf("failed to load system environment from registry: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.setUserEnvironmentVariables(envMap, userProfile, username, domain)
|
s.setUserEnvironmentVariables(envMap, userProfile, username, domain)
|
||||||
@@ -59,8 +59,8 @@ func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getUserToken creates a user token for the specified user.
|
// getUserToken creates a user token for the specified user.
|
||||||
func (s *Server) getUserToken(username, domain string) (windows.Handle, error) {
|
func (s *Server) getUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
|
||||||
privilegeDropper := NewPrivilegeDropper()
|
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
|
||||||
token, err := privilegeDropper.createToken(username, domain)
|
token, err := privilegeDropper.createToken(username, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("generate S4U user token: %w", err)
|
return 0, fmt.Errorf("generate S4U user token: %w", err)
|
||||||
@@ -242,9 +242,9 @@ func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfi
|
|||||||
}
|
}
|
||||||
|
|
||||||
// prepareCommandEnv prepares environment variables for command execution on Windows
|
// prepareCommandEnv prepares environment variables for command execution on Windows
|
||||||
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, session ssh.Session) []string {
|
||||||
username, domain := s.parseUsername(localUser.Username)
|
username, domain := s.parseUsername(localUser.Username)
|
||||||
userEnv, err := s.getUserEnvironment(username, domain)
|
userEnv, err := s.getUserEnvironment(logger, username, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
|
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
|
||||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||||
@@ -267,22 +267,16 @@ func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []
|
|||||||
return env
|
return env
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool {
|
||||||
if privilegeResult.User == nil {
|
if privilegeResult.User == nil {
|
||||||
logger.Errorf("no user in privilege result")
|
logger.Errorf("no user in privilege result")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := session.Command()
|
|
||||||
shell := getUserShell(privilegeResult.User.Uid)
|
shell := getUserShell(privilegeResult.User.Uid)
|
||||||
|
logger.Infof("starting interactive shell: %s", shell)
|
||||||
|
|
||||||
if len(cmd) == 0 {
|
s.executeCommandWithPty(logger, session, nil, privilegeResult, ptyReq, nil)
|
||||||
logger.Infof("starting interactive shell: %s", shell)
|
|
||||||
} else {
|
|
||||||
logger.Infof("executing command: %s", safeLogCommand(cmd))
|
|
||||||
}
|
|
||||||
|
|
||||||
s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd)
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -294,11 +288,6 @@ func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
|||||||
return []string{shell, "-Command", cmdString}
|
return []string{shell, "-Command", cmdString}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) {
|
|
||||||
logger.Info("starting interactive shell")
|
|
||||||
s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand())
|
|
||||||
}
|
|
||||||
|
|
||||||
type PtyExecutionRequest struct {
|
type PtyExecutionRequest struct {
|
||||||
Shell string
|
Shell string
|
||||||
Command string
|
Command string
|
||||||
@@ -308,25 +297,25 @@ type PtyExecutionRequest struct {
|
|||||||
Domain string
|
Domain string
|
||||||
}
|
}
|
||||||
|
|
||||||
func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error {
|
func executePtyCommandWithUserToken(logger *log.Entry, session ssh.Session, req PtyExecutionRequest) error {
|
||||||
log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
|
logger.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
|
||||||
req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height)
|
req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height)
|
||||||
|
|
||||||
privilegeDropper := NewPrivilegeDropper()
|
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
|
||||||
userToken, err := privilegeDropper.createToken(req.Username, req.Domain)
|
userToken, err := privilegeDropper.createToken(req.Username, req.Domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create user token: %w", err)
|
return fmt.Errorf("create user token: %w", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := windows.CloseHandle(userToken); err != nil {
|
if err := windows.CloseHandle(userToken); err != nil {
|
||||||
log.Debugf("close user token: %v", err)
|
logger.Debugf("close user token: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
server := &Server{}
|
server := &Server{}
|
||||||
userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain)
|
userEnv, err := server.getUserEnvironmentWithToken(logger, userToken, req.Username, req.Domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
|
logger.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
|
||||||
userEnv = os.Environ()
|
userEnv = os.Environ()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -348,8 +337,8 @@ func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, re
|
|||||||
Environment: userEnv,
|
Environment: userEnv,
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
|
logger.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
|
||||||
return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig)
|
return winpty.ExecutePtyWithUserToken(session, ptyConfig, userConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getUserHomeFromEnv(env []string) string {
|
func getUserHomeFromEnv(env []string) string {
|
||||||
@@ -371,10 +360,8 @@ func (s *Server) killProcessGroup(cmd *exec.Cmd) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logger := log.WithField("pid", cmd.Process.Pid)
|
|
||||||
|
|
||||||
if err := cmd.Process.Kill(); err != nil {
|
if err := cmd.Process.Kill(); err != nil {
|
||||||
logger.Debugf("kill process failed: %v", err)
|
log.Debugf("kill process %d failed: %v", cmd.Process.Pid, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -389,21 +376,7 @@ func (s *Server) detectUtilLinuxLogin(context.Context) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
|
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
|
||||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, _ *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool {
|
||||||
command := session.RawCommand()
|
|
||||||
if command == "" {
|
|
||||||
logger.Error("no command specified for PTY execution")
|
|
||||||
if err := session.Exit(1); err != nil {
|
|
||||||
logSessionExitError(logger, err)
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command)
|
|
||||||
}
|
|
||||||
|
|
||||||
// executeConPtyCommand executes a command using ConPty (common for interactive and command execution)
|
|
||||||
func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool {
|
|
||||||
localUser := privilegeResult.User
|
localUser := privilegeResult.User
|
||||||
if localUser == nil {
|
if localUser == nil {
|
||||||
logger.Errorf("no user in privilege result")
|
logger.Errorf("no user in privilege result")
|
||||||
@@ -415,14 +388,14 @@ func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, pr
|
|||||||
|
|
||||||
req := PtyExecutionRequest{
|
req := PtyExecutionRequest{
|
||||||
Shell: shell,
|
Shell: shell,
|
||||||
Command: command,
|
Command: session.RawCommand(),
|
||||||
Width: ptyReq.Window.Width,
|
Width: ptyReq.Window.Width,
|
||||||
Height: ptyReq.Window.Height,
|
Height: ptyReq.Window.Height,
|
||||||
Username: username,
|
Username: username,
|
||||||
Domain: domain,
|
Domain: domain,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil {
|
if err := executePtyCommandWithUserToken(logger, session, req); err != nil {
|
||||||
logger.Errorf("ConPty execution failed: %v", err)
|
logger.Errorf("ConPty execution failed: %v", err)
|
||||||
if err := session.Exit(1); err != nil {
|
if err := session.Exit(1); err != nil {
|
||||||
logSessionExitError(logger, err)
|
logSessionExitError(logger, err)
|
||||||
|
|||||||
@@ -4,12 +4,15 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -23,25 +26,67 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestMain handles package-level setup and cleanup
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
|
// On platforms where su doesn't support --pty (macOS, FreeBSD, Windows), the SSH server
|
||||||
// This happens when running tests as non-privileged user with fallback
|
// spawns an executor subprocess via os.Executable(). During tests, this invokes the test
|
||||||
|
// binary with "ssh exec" args. We handle that here to properly execute commands and
|
||||||
|
// propagate exit codes.
|
||||||
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
||||||
// Just exit with error to break the recursion
|
runTestExecutor()
|
||||||
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
|
return
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run tests
|
|
||||||
code := m.Run()
|
code := m.Run()
|
||||||
|
|
||||||
// Cleanup any created test users
|
|
||||||
testutil.CleanupTestUsers()
|
testutil.CleanupTestUsers()
|
||||||
|
|
||||||
os.Exit(code)
|
os.Exit(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// runTestExecutor emulates the netbird executor for tests.
|
||||||
|
// Parses --shell and --cmd args, runs the command, and exits with the correct code.
|
||||||
|
func runTestExecutor() {
|
||||||
|
if os.Getenv("_NETBIRD_TEST_EXECUTOR") != "" {
|
||||||
|
fmt.Fprintf(os.Stderr, "executor recursion detected\n")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
os.Setenv("_NETBIRD_TEST_EXECUTOR", "1")
|
||||||
|
|
||||||
|
shell := "/bin/sh"
|
||||||
|
var command string
|
||||||
|
for i := 3; i < len(os.Args); i++ {
|
||||||
|
switch os.Args[i] {
|
||||||
|
case "--shell":
|
||||||
|
if i+1 < len(os.Args) {
|
||||||
|
shell = os.Args[i+1]
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
case "--cmd":
|
||||||
|
if i+1 < len(os.Args) {
|
||||||
|
command = os.Args[i+1]
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd *exec.Cmd
|
||||||
|
if command == "" {
|
||||||
|
cmd = exec.Command(shell)
|
||||||
|
} else {
|
||||||
|
cmd = exec.Command(shell, "-c", command)
|
||||||
|
}
|
||||||
|
cmd.Args[0] = "-" + filepath.Base(shell)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||||
|
os.Exit(exitErr.ExitCode())
|
||||||
|
}
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
|
||||||
// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client
|
// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client
|
||||||
func TestSSHServerCompatibility(t *testing.T) {
|
func TestSSHServerCompatibility(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
@@ -405,6 +450,171 @@ func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) {
|
|||||||
return createTempKeyFileFromBytes(t, privateKey)
|
return createTempKeyFileFromBytes(t, privateKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestSSHPtyModes tests different PTY allocation modes (-T, -t, -tt flags)
|
||||||
|
// This ensures our implementation matches OpenSSH behavior for:
|
||||||
|
// - ssh host command (no PTY - default when no TTY)
|
||||||
|
// - ssh -T host command (explicit no PTY)
|
||||||
|
// - ssh -t host command (force PTY)
|
||||||
|
// - ssh -T host (no PTY shell - our implementation)
|
||||||
|
func TestSSHPtyModes(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping SSH PTY mode tests in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isSSHClientAvailable() {
|
||||||
|
t.Skip("SSH client not available on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||||
|
t.Skip("Skipping Windows SSH PTY tests in CI due to S4U authentication issues")
|
||||||
|
}
|
||||||
|
|
||||||
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
serverConfig := &Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: nil,
|
||||||
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
|
server.SetAllowRootLogin(true)
|
||||||
|
|
||||||
|
serverAddr := StartTestServer(t, server)
|
||||||
|
defer func() {
|
||||||
|
err := server.Stop()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH)
|
||||||
|
defer cleanupKey()
|
||||||
|
|
||||||
|
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
username := testutil.GetTestUsername(t)
|
||||||
|
|
||||||
|
baseArgs := []string{
|
||||||
|
"-i", clientKeyFile,
|
||||||
|
"-p", portStr,
|
||||||
|
"-o", "StrictHostKeyChecking=no",
|
||||||
|
"-o", "UserKnownHostsFile=/dev/null",
|
||||||
|
"-o", "ConnectTimeout=5",
|
||||||
|
"-o", "BatchMode=yes",
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("command_default_no_pty", func(t *testing.T) {
|
||||||
|
args := append(slices.Clone(baseArgs), fmt.Sprintf("%s@%s", username, host), "echo", "no_pty_default")
|
||||||
|
cmd := exec.Command("ssh", args...)
|
||||||
|
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
require.NoError(t, err, "Command (default no PTY) failed: %s", output)
|
||||||
|
assert.Contains(t, string(output), "no_pty_default")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("command_explicit_no_pty", func(t *testing.T) {
|
||||||
|
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "echo", "explicit_no_pty")
|
||||||
|
cmd := exec.Command("ssh", args...)
|
||||||
|
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
require.NoError(t, err, "Command (-T explicit no PTY) failed: %s", output)
|
||||||
|
assert.Contains(t, string(output), "explicit_no_pty")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("command_force_pty", func(t *testing.T) {
|
||||||
|
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "echo", "force_pty")
|
||||||
|
cmd := exec.Command("ssh", args...)
|
||||||
|
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
require.NoError(t, err, "Command (-tt force PTY) failed: %s", output)
|
||||||
|
assert.Contains(t, string(output), "force_pty")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("shell_explicit_no_pty", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host))
|
||||||
|
cmd := exec.CommandContext(ctx, "ssh", args...)
|
||||||
|
|
||||||
|
stdin, err := cmd.StdinPipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
stdout, err := cmd.StdoutPipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, cmd.Start(), "Shell (-T no PTY) start failed")
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer stdin.Close()
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
_, err := stdin.Write([]byte("echo shell_no_pty_test\n"))
|
||||||
|
assert.NoError(t, err, "write echo command")
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
_, err = stdin.Write([]byte("exit 0\n"))
|
||||||
|
assert.NoError(t, err, "write exit command")
|
||||||
|
}()
|
||||||
|
|
||||||
|
output, _ := io.ReadAll(stdout)
|
||||||
|
err = cmd.Wait()
|
||||||
|
|
||||||
|
require.NoError(t, err, "Shell (-T no PTY) failed: %s", output)
|
||||||
|
assert.Contains(t, string(output), "shell_no_pty_test")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("exit_code_preserved_no_pty", func(t *testing.T) {
|
||||||
|
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "exit", "42")
|
||||||
|
cmd := exec.Command("ssh", args...)
|
||||||
|
|
||||||
|
err := cmd.Run()
|
||||||
|
require.Error(t, err, "Command should exit with non-zero")
|
||||||
|
|
||||||
|
var exitErr *exec.ExitError
|
||||||
|
require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err)
|
||||||
|
assert.Equal(t, 42, exitErr.ExitCode(), "Exit code should be preserved with -T")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("exit_code_preserved_with_pty", func(t *testing.T) {
|
||||||
|
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "sh -c 'exit 43'")
|
||||||
|
cmd := exec.Command("ssh", args...)
|
||||||
|
|
||||||
|
err := cmd.Run()
|
||||||
|
require.Error(t, err, "PTY command should exit with non-zero")
|
||||||
|
|
||||||
|
var exitErr *exec.ExitError
|
||||||
|
require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err)
|
||||||
|
assert.Equal(t, 43, exitErr.ExitCode(), "Exit code should be preserved with -tt")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("stderr_works_no_pty", func(t *testing.T) {
|
||||||
|
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host),
|
||||||
|
"sh -c 'echo stdout_msg; echo stderr_msg >&2'")
|
||||||
|
cmd := exec.Command("ssh", args...)
|
||||||
|
|
||||||
|
var stdout, stderr strings.Builder
|
||||||
|
cmd.Stdout = &stdout
|
||||||
|
cmd.Stderr = &stderr
|
||||||
|
|
||||||
|
require.NoError(t, cmd.Run(), "stderr test failed")
|
||||||
|
assert.Contains(t, stdout.String(), "stdout_msg", "stdout should have stdout_msg")
|
||||||
|
assert.Contains(t, stderr.String(), "stderr_msg", "stderr should have stderr_msg")
|
||||||
|
assert.NotContains(t, stdout.String(), "stderr_msg", "stdout should NOT have stderr_msg")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("stderr_merged_with_pty", func(t *testing.T) {
|
||||||
|
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host),
|
||||||
|
"sh -c 'echo stdout_msg; echo stderr_msg >&2'")
|
||||||
|
cmd := exec.Command("ssh", args...)
|
||||||
|
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
require.NoError(t, err, "PTY stderr test failed: %s", output)
|
||||||
|
assert.Contains(t, string(output), "stdout_msg")
|
||||||
|
assert.Contains(t, string(output), "stderr_msg")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility
|
// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility
|
||||||
func TestSSHServerFeatureCompatibility(t *testing.T) {
|
func TestSSHServerFeatureCompatibility(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
@@ -35,11 +36,35 @@ type ExecutorConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// PrivilegeDropper handles secure privilege dropping in child processes
|
// PrivilegeDropper handles secure privilege dropping in child processes
|
||||||
type PrivilegeDropper struct{}
|
type PrivilegeDropper struct {
|
||||||
|
logger *log.Entry
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper
|
||||||
|
type PrivilegeDropperOption func(*PrivilegeDropper)
|
||||||
|
|
||||||
// NewPrivilegeDropper creates a new privilege dropper
|
// NewPrivilegeDropper creates a new privilege dropper
|
||||||
func NewPrivilegeDropper() *PrivilegeDropper {
|
func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper {
|
||||||
return &PrivilegeDropper{}
|
pd := &PrivilegeDropper{}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(pd)
|
||||||
|
}
|
||||||
|
return pd
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithLogger sets the logger for the PrivilegeDropper
|
||||||
|
func WithLogger(logger *log.Entry) PrivilegeDropperOption {
|
||||||
|
return func(pd *PrivilegeDropper) {
|
||||||
|
pd.logger = logger
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// log returns the logger, falling back to standard logger if none set
|
||||||
|
func (pd *PrivilegeDropper) log() *log.Entry {
|
||||||
|
if pd.logger != nil {
|
||||||
|
return pd.logger
|
||||||
|
}
|
||||||
|
return log.NewEntry(log.StandardLogger())
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping
|
// CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping
|
||||||
@@ -83,7 +108,7 @@ func (pd *PrivilegeDropper) CreateExecutorCommand(ctx context.Context, config Ex
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
|
pd.log().Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
|
||||||
return exec.CommandContext(ctx, netbirdPath, args...), nil
|
return exec.CommandContext(ctx, netbirdPath, args...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -206,17 +231,22 @@ func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config
|
|||||||
|
|
||||||
var execCmd *exec.Cmd
|
var execCmd *exec.Cmd
|
||||||
if config.Command == "" {
|
if config.Command == "" {
|
||||||
os.Exit(ExitCodeSuccess)
|
execCmd = exec.CommandContext(ctx, config.Shell)
|
||||||
|
} else {
|
||||||
|
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
|
||||||
}
|
}
|
||||||
|
execCmd.Args[0] = "-" + filepath.Base(config.Shell)
|
||||||
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
|
|
||||||
execCmd.Stdin = os.Stdin
|
execCmd.Stdin = os.Stdin
|
||||||
execCmd.Stdout = os.Stdout
|
execCmd.Stdout = os.Stdout
|
||||||
execCmd.Stderr = os.Stderr
|
execCmd.Stderr = os.Stderr
|
||||||
|
|
||||||
cmdParts := strings.Fields(config.Command)
|
if config.Command == "" {
|
||||||
safeCmd := safeLogCommand(cmdParts)
|
log.Tracef("executing login shell: %s", execCmd.Path)
|
||||||
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
|
} else {
|
||||||
|
cmdParts := strings.Fields(config.Command)
|
||||||
|
safeCmd := safeLogCommand(cmdParts)
|
||||||
|
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
|
||||||
|
}
|
||||||
if err := execCmd.Run(); err != nil {
|
if err := execCmd.Run(); err != nil {
|
||||||
var exitError *exec.ExitError
|
var exitError *exec.ExitError
|
||||||
if errors.As(err, &exitError) {
|
if errors.As(err, &exitError) {
|
||||||
|
|||||||
@@ -28,22 +28,45 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type WindowsExecutorConfig struct {
|
type WindowsExecutorConfig struct {
|
||||||
Username string
|
Username string
|
||||||
Domain string
|
Domain string
|
||||||
WorkingDir string
|
WorkingDir string
|
||||||
Shell string
|
Shell string
|
||||||
Command string
|
Command string
|
||||||
Args []string
|
Args []string
|
||||||
Interactive bool
|
Pty bool
|
||||||
Pty bool
|
PtyWidth int
|
||||||
PtyWidth int
|
PtyHeight int
|
||||||
PtyHeight int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type PrivilegeDropper struct{}
|
type PrivilegeDropper struct {
|
||||||
|
logger *log.Entry
|
||||||
|
}
|
||||||
|
|
||||||
func NewPrivilegeDropper() *PrivilegeDropper {
|
// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper
|
||||||
return &PrivilegeDropper{}
|
type PrivilegeDropperOption func(*PrivilegeDropper)
|
||||||
|
|
||||||
|
func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper {
|
||||||
|
pd := &PrivilegeDropper{}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(pd)
|
||||||
|
}
|
||||||
|
return pd
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithLogger sets the logger for the PrivilegeDropper
|
||||||
|
func WithLogger(logger *log.Entry) PrivilegeDropperOption {
|
||||||
|
return func(pd *PrivilegeDropper) {
|
||||||
|
pd.logger = logger
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// log returns the logger, falling back to standard logger if none set
|
||||||
|
func (pd *PrivilegeDropper) log() *log.Entry {
|
||||||
|
if pd.logger != nil {
|
||||||
|
return pd.logger
|
||||||
|
}
|
||||||
|
return log.NewEntry(log.StandardLogger())
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -56,7 +79,6 @@ const (
|
|||||||
|
|
||||||
// Common error messages
|
// Common error messages
|
||||||
commandFlag = "-Command"
|
commandFlag = "-Command"
|
||||||
closeTokenErrorMsg = "close token error: %v" // #nosec G101 -- This is an error message template, not credentials
|
|
||||||
convertUsernameError = "convert username to UTF16: %w"
|
convertUsernameError = "convert username to UTF16: %w"
|
||||||
convertDomainError = "convert domain to UTF16: %w"
|
convertDomainError = "convert domain to UTF16: %w"
|
||||||
)
|
)
|
||||||
@@ -80,7 +102,7 @@ func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, co
|
|||||||
shellArgs = []string{shell}
|
shellArgs = []string{shell}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
|
pd.log().Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
|
||||||
|
|
||||||
cmd, token, err := pd.CreateWindowsProcessAsUser(
|
cmd, token, err := pd.CreateWindowsProcessAsUser(
|
||||||
ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir)
|
ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir)
|
||||||
@@ -180,10 +202,10 @@ func newLsaString(s string) lsaString {
|
|||||||
|
|
||||||
// generateS4UUserToken creates a Windows token using S4U authentication
|
// generateS4UUserToken creates a Windows token using S4U authentication
|
||||||
// This is the exact approach OpenSSH for Windows uses for public key authentication
|
// This is the exact approach OpenSSH for Windows uses for public key authentication
|
||||||
func generateS4UUserToken(username, domain string) (windows.Handle, error) {
|
func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
|
||||||
userCpn := buildUserCpn(username, domain)
|
userCpn := buildUserCpn(username, domain)
|
||||||
|
|
||||||
pd := NewPrivilegeDropper()
|
pd := NewPrivilegeDropper(WithLogger(logger))
|
||||||
isDomainUser := !pd.isLocalUser(domain)
|
isDomainUser := !pd.isLocalUser(domain)
|
||||||
|
|
||||||
lsaHandle, err := initializeLsaConnection()
|
lsaHandle, err := initializeLsaConnection()
|
||||||
@@ -197,12 +219,12 @@ func generateS4UUserToken(username, domain string) (windows.Handle, error) {
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(username, domain, isDomainUser)
|
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(logger, username, domain, isDomainUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return performS4ULogon(lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
|
return performS4ULogon(logger, lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildUserCpn constructs the user principal name
|
// buildUserCpn constructs the user principal name
|
||||||
@@ -310,21 +332,21 @@ func lookupPrincipalName(username, domain string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// prepareS4ULogonStructure creates the appropriate S4U logon structure
|
// prepareS4ULogonStructure creates the appropriate S4U logon structure
|
||||||
func prepareS4ULogonStructure(username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
|
func prepareS4ULogonStructure(logger *log.Entry, username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
|
||||||
if isDomainUser {
|
if isDomainUser {
|
||||||
return prepareDomainS4ULogon(username, domain)
|
return prepareDomainS4ULogon(logger, username, domain)
|
||||||
}
|
}
|
||||||
return prepareLocalS4ULogon(username)
|
return prepareLocalS4ULogon(logger, username)
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepareDomainS4ULogon creates S4U logon structure for domain users
|
// prepareDomainS4ULogon creates S4U logon structure for domain users
|
||||||
func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, error) {
|
func prepareDomainS4ULogon(logger *log.Entry, username, domain string) (unsafe.Pointer, uintptr, error) {
|
||||||
upn, err := lookupPrincipalName(username, domain)
|
upn, err := lookupPrincipalName(username, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("lookup principal name: %w", err)
|
return nil, 0, fmt.Errorf("lookup principal name: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
|
logger.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
|
||||||
|
|
||||||
upnUtf16, err := windows.UTF16FromString(upn)
|
upnUtf16, err := windows.UTF16FromString(upn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -357,8 +379,8 @@ func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// prepareLocalS4ULogon creates S4U logon structure for local users
|
// prepareLocalS4ULogon creates S4U logon structure for local users
|
||||||
func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) {
|
func prepareLocalS4ULogon(logger *log.Entry, username string) (unsafe.Pointer, uintptr, error) {
|
||||||
log.Debugf("using Msv1_0S4ULogon for local user: %s", username)
|
logger.Debugf("using Msv1_0S4ULogon for local user: %s", username)
|
||||||
|
|
||||||
usernameUtf16, err := windows.UTF16FromString(username)
|
usernameUtf16, err := windows.UTF16FromString(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -406,11 +428,11 @@ func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// performS4ULogon executes the S4U logon operation
|
// performS4ULogon executes the S4U logon operation
|
||||||
func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
|
func performS4ULogon(logger *log.Entry, lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
|
||||||
var tokenSource tokenSource
|
var tokenSource tokenSource
|
||||||
copy(tokenSource.SourceName[:], "netbird")
|
copy(tokenSource.SourceName[:], "netbird")
|
||||||
if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 {
|
if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 {
|
||||||
log.Debugf("AllocateLocallyUniqueId failed")
|
logger.Debugf("AllocateLocallyUniqueId failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
originName := newLsaString("netbird")
|
originName := newLsaString("netbird")
|
||||||
@@ -441,7 +463,7 @@ func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo u
|
|||||||
|
|
||||||
if profile != 0 {
|
if profile != 0 {
|
||||||
if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess {
|
if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess {
|
||||||
log.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
|
logger.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -449,7 +471,7 @@ func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo u
|
|||||||
return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus)
|
return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("created S4U %s token for user %s",
|
logger.Debugf("created S4U %s token for user %s",
|
||||||
map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn)
|
map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn)
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
@@ -497,8 +519,8 @@ func (pd *PrivilegeDropper) isLocalUser(domain string) bool {
|
|||||||
|
|
||||||
// authenticateLocalUser handles authentication for local users
|
// authenticateLocalUser handles authentication for local users
|
||||||
func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) {
|
func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) {
|
||||||
log.Debugf("using S4U authentication for local user %s", fullUsername)
|
pd.log().Debugf("using S4U authentication for local user %s", fullUsername)
|
||||||
token, err := generateS4UUserToken(username, ".")
|
token, err := generateS4UUserToken(pd.log(), username, ".")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err)
|
return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err)
|
||||||
}
|
}
|
||||||
@@ -507,12 +529,12 @@ func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string)
|
|||||||
|
|
||||||
// authenticateDomainUser handles authentication for domain users
|
// authenticateDomainUser handles authentication for domain users
|
||||||
func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) {
|
func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) {
|
||||||
log.Debugf("using S4U authentication for domain user %s", fullUsername)
|
pd.log().Debugf("using S4U authentication for domain user %s", fullUsername)
|
||||||
token, err := generateS4UUserToken(username, domain)
|
token, err := generateS4UUserToken(pd.log(), username, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err)
|
return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err)
|
||||||
}
|
}
|
||||||
log.Debugf("Successfully created S4U token for domain user %s", fullUsername)
|
pd.log().Debugf("successfully created S4U token for domain user %s", fullUsername)
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -526,7 +548,7 @@ func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, exec
|
|||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := windows.CloseHandle(token); err != nil {
|
if err := windows.CloseHandle(token); err != nil {
|
||||||
log.Debugf("close impersonation token: %v", err)
|
pd.log().Debugf("close impersonation token: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -564,7 +586,7 @@ func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceTo
|
|||||||
return cmd, primaryToken, nil
|
return cmd, primaryToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createSuCommand creates a command using su -l -c for privilege switching (Windows stub)
|
// createSuCommand creates a command using su - for privilege switching (Windows stub).
|
||||||
func (s *Server) createSuCommand(ssh.Session, *user.User, bool) (*exec.Cmd, error) {
|
func (s *Server) createSuCommand(*log.Entry, ssh.Session, *user.User, bool) (*exec.Cmd, error) {
|
||||||
return nil, fmt.Errorf("su command not available on Windows")
|
return nil, fmt.Errorf("su command not available on Windows")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ func TestJWTEnforcement(t *testing.T) {
|
|||||||
server.SetAllowRootLogin(true)
|
server.SetAllowRootLogin(true)
|
||||||
|
|
||||||
serverAddr := StartTestServer(t, server)
|
serverAddr := StartTestServer(t, server)
|
||||||
defer require.NoError(t, server.Stop())
|
defer func() { require.NoError(t, server.Stop()) }()
|
||||||
|
|
||||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -88,7 +88,7 @@ func TestJWTEnforcement(t *testing.T) {
|
|||||||
serverNoJWT.SetAllowRootLogin(true)
|
serverNoJWT.SetAllowRootLogin(true)
|
||||||
|
|
||||||
serverAddrNoJWT := StartTestServer(t, serverNoJWT)
|
serverAddrNoJWT := StartTestServer(t, serverNoJWT)
|
||||||
defer require.NoError(t, serverNoJWT.Stop())
|
defer func() { require.NoError(t, serverNoJWT.Stop()) }()
|
||||||
|
|
||||||
hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT)
|
hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -213,7 +213,7 @@ func TestJWTDetection(t *testing.T) {
|
|||||||
server.SetAllowRootLogin(true)
|
server.SetAllowRootLogin(true)
|
||||||
|
|
||||||
serverAddr := StartTestServer(t, server)
|
serverAddr := StartTestServer(t, server)
|
||||||
defer require.NoError(t, server.Stop())
|
defer func() { require.NoError(t, server.Stop()) }()
|
||||||
|
|
||||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -341,7 +341,7 @@ func TestJWTFailClose(t *testing.T) {
|
|||||||
server.SetAllowRootLogin(true)
|
server.SetAllowRootLogin(true)
|
||||||
|
|
||||||
serverAddr := StartTestServer(t, server)
|
serverAddr := StartTestServer(t, server)
|
||||||
defer require.NoError(t, server.Stop())
|
defer func() { require.NoError(t, server.Stop()) }()
|
||||||
|
|
||||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -596,7 +596,7 @@ func TestJWTAuthentication(t *testing.T) {
|
|||||||
server.UpdateSSHAuth(authConfig)
|
server.UpdateSSHAuth(authConfig)
|
||||||
|
|
||||||
serverAddr := StartTestServer(t, server)
|
serverAddr := StartTestServer(t, server)
|
||||||
defer require.NoError(t, server.Stop())
|
defer func() { require.NoError(t, server.Stop()) }()
|
||||||
|
|
||||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -715,7 +715,7 @@ func TestJWTMultipleAudiences(t *testing.T) {
|
|||||||
server.UpdateSSHAuth(authConfig)
|
server.UpdateSSHAuth(authConfig)
|
||||||
|
|
||||||
serverAddr := StartTestServer(t, server)
|
serverAddr := StartTestServer(t, server)
|
||||||
defer require.NoError(t, server.Stop())
|
defer func() { require.NoError(t, server.Stop()) }()
|
||||||
|
|
||||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -271,13 +271,6 @@ func (s *Server) isRemotePortForwardingAllowed() bool {
|
|||||||
return s.allowRemotePortForwarding
|
return s.allowRemotePortForwarding
|
||||||
}
|
}
|
||||||
|
|
||||||
// isPortForwardingEnabled checks if any port forwarding (local or remote) is enabled
|
|
||||||
func (s *Server) isPortForwardingEnabled() bool {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
return s.allowLocalPortForwarding || s.allowRemotePortForwarding
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseTcpipForwardRequest parses the SSH request payload
|
// parseTcpipForwardRequest parses the SSH request payload
|
||||||
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
|
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
|
||||||
var payload tcpipForwardMsg
|
var payload tcpipForwardMsg
|
||||||
|
|||||||
@@ -335,7 +335,7 @@ func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
|
|||||||
sessions = append(sessions, info)
|
sessions = append(sessions, info)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add authenticated connections without sessions (e.g., -N/-T or port-forwarding only)
|
// Add authenticated connections without sessions (e.g., -N or port-forwarding only)
|
||||||
for key, connState := range s.connections {
|
for key, connState := range s.connections {
|
||||||
remoteAddr := string(key)
|
remoteAddr := string(key)
|
||||||
if reportedAddrs[remoteAddr] {
|
if reportedAddrs[remoteAddr] {
|
||||||
|
|||||||
@@ -483,12 +483,11 @@ func TestServer_IsPrivilegedUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_PortForwardingOnlySession(t *testing.T) {
|
func TestServer_NonPtyShellSession(t *testing.T) {
|
||||||
// Test that sessions without PTY and command are allowed when port forwarding is enabled
|
// Test that non-PTY shell sessions (ssh -T) work regardless of port forwarding settings.
|
||||||
currentUser, err := user.Current()
|
currentUser, err := user.Current()
|
||||||
require.NoError(t, err, "Should be able to get current user")
|
require.NoError(t, err, "Should be able to get current user")
|
||||||
|
|
||||||
// Generate host key for server
|
|
||||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -496,36 +495,26 @@ func TestServer_PortForwardingOnlySession(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
allowLocalForwarding bool
|
allowLocalForwarding bool
|
||||||
allowRemoteForwarding bool
|
allowRemoteForwarding bool
|
||||||
expectAllowed bool
|
|
||||||
description string
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "session_allowed_with_local_forwarding",
|
name: "shell_with_local_forwarding_enabled",
|
||||||
allowLocalForwarding: true,
|
allowLocalForwarding: true,
|
||||||
allowRemoteForwarding: false,
|
allowRemoteForwarding: false,
|
||||||
expectAllowed: true,
|
|
||||||
description: "Port-forwarding-only session should be allowed when local forwarding is enabled",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "session_allowed_with_remote_forwarding",
|
name: "shell_with_remote_forwarding_enabled",
|
||||||
allowLocalForwarding: false,
|
allowLocalForwarding: false,
|
||||||
allowRemoteForwarding: true,
|
allowRemoteForwarding: true,
|
||||||
expectAllowed: true,
|
|
||||||
description: "Port-forwarding-only session should be allowed when remote forwarding is enabled",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "session_allowed_with_both",
|
name: "shell_with_both_forwarding_enabled",
|
||||||
allowLocalForwarding: true,
|
allowLocalForwarding: true,
|
||||||
allowRemoteForwarding: true,
|
allowRemoteForwarding: true,
|
||||||
expectAllowed: true,
|
|
||||||
description: "Port-forwarding-only session should be allowed when both forwarding types enabled",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "session_denied_without_forwarding",
|
name: "shell_with_forwarding_disabled",
|
||||||
allowLocalForwarding: false,
|
allowLocalForwarding: false,
|
||||||
allowRemoteForwarding: false,
|
allowRemoteForwarding: false,
|
||||||
expectAllowed: false,
|
|
||||||
description: "Port-forwarding-only session should be denied when all forwarding is disabled",
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -545,7 +534,6 @@ func TestServer_PortForwardingOnlySession(t *testing.T) {
|
|||||||
_ = server.Stop()
|
_ = server.Stop()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Connect to the server without requesting PTY or command
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@@ -557,20 +545,10 @@ func TestServer_PortForwardingOnlySession(t *testing.T) {
|
|||||||
_ = client.Close()
|
_ = client.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Execute a command without PTY - this simulates ssh -T with no command
|
// Execute without PTY and no command - simulates ssh -T (shell without PTY)
|
||||||
// The server should either allow it (port forwarding enabled) or reject it
|
// Should always succeed regardless of port forwarding settings
|
||||||
output, err := client.ExecuteCommand(ctx, "")
|
_, err = client.ExecuteCommand(ctx, "")
|
||||||
if tt.expectAllowed {
|
assert.NoError(t, err, "Non-PTY shell session should be allowed")
|
||||||
// When allowed, the session stays open until cancelled
|
|
||||||
// ExecuteCommand with empty command should return without error
|
|
||||||
assert.NoError(t, err, "Session should be allowed when port forwarding is enabled")
|
|
||||||
assert.NotContains(t, output, "port forwarding is disabled",
|
|
||||||
"Output should not contain port forwarding disabled message")
|
|
||||||
} else if err != nil {
|
|
||||||
// When denied, we expect an error message about port forwarding being disabled
|
|
||||||
assert.Contains(t, err.Error(), "port forwarding is disabled",
|
|
||||||
"Should get port forwarding disabled message")
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -405,12 +405,14 @@ func TestSSHServer_WindowsShellHandling(t *testing.T) {
|
|||||||
assert.Equal(t, "-Command", args[1])
|
assert.Equal(t, "-Command", args[1])
|
||||||
assert.Equal(t, "echo test", args[2])
|
assert.Equal(t, "echo test", args[2])
|
||||||
} else {
|
} else {
|
||||||
// Test Unix shell behavior
|
|
||||||
args := server.getShellCommandArgs("/bin/sh", "echo test")
|
args := server.getShellCommandArgs("/bin/sh", "echo test")
|
||||||
assert.Equal(t, "/bin/sh", args[0])
|
assert.Equal(t, "/bin/sh", args[0])
|
||||||
assert.Equal(t, "-l", args[1])
|
assert.Equal(t, "-c", args[1])
|
||||||
assert.Equal(t, "-c", args[2])
|
assert.Equal(t, "echo test", args[2])
|
||||||
assert.Equal(t, "echo test", args[3])
|
|
||||||
|
args = server.getShellCommandArgs("/bin/sh", "")
|
||||||
|
assert.Equal(t, "/bin/sh", args[0])
|
||||||
|
assert.Len(t, args, 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -62,54 +62,12 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
|||||||
ptyReq, winCh, isPty := session.Pty()
|
ptyReq, winCh, isPty := session.Pty()
|
||||||
hasCommand := len(session.Command()) > 0
|
hasCommand := len(session.Command()) > 0
|
||||||
|
|
||||||
switch {
|
if isPty && !hasCommand {
|
||||||
case isPty && hasCommand:
|
// ssh <host> - PTY interactive session (login)
|
||||||
// ssh -t <host> <cmd> - Pty command execution
|
s.handlePtyLogin(logger, session, privilegeResult, ptyReq, winCh)
|
||||||
s.handleCommand(logger, session, privilegeResult, winCh)
|
} else {
|
||||||
case isPty:
|
// ssh <host> <cmd>, ssh -t <host> <cmd>, ssh -T <host> - command or shell execution
|
||||||
// ssh <host> - Pty interactive session (login)
|
s.handleExecution(logger, session, privilegeResult, ptyReq, winCh)
|
||||||
s.handlePty(logger, session, privilegeResult, ptyReq, winCh)
|
|
||||||
case hasCommand:
|
|
||||||
// ssh <host> <cmd> - non-Pty command execution
|
|
||||||
s.handleCommand(logger, session, privilegeResult, nil)
|
|
||||||
default:
|
|
||||||
// ssh -T (or ssh -N) - no PTY, no command
|
|
||||||
s.handleNonInteractiveSession(logger, session)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleNonInteractiveSession handles sessions that have no PTY and no command.
|
|
||||||
// These are typically used for port forwarding (ssh -L/-R) or tunneling (ssh -N).
|
|
||||||
func (s *Server) handleNonInteractiveSession(logger *log.Entry, session ssh.Session) {
|
|
||||||
s.updateSessionType(session, cmdNonInteractive)
|
|
||||||
|
|
||||||
if !s.isPortForwardingEnabled() {
|
|
||||||
if _, err := io.WriteString(session, "port forwarding is disabled on this server\n"); err != nil {
|
|
||||||
logger.Debugf(errWriteSession, err)
|
|
||||||
}
|
|
||||||
if err := session.Exit(1); err != nil {
|
|
||||||
logSessionExitError(logger, err)
|
|
||||||
}
|
|
||||||
logger.Infof("rejected non-interactive session: port forwarding disabled")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
<-session.Context().Done()
|
|
||||||
|
|
||||||
if err := session.Exit(0); err != nil {
|
|
||||||
logSessionExitError(logger, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) updateSessionType(session ssh.Session, sessionType string) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
for _, state := range s.sessions {
|
|
||||||
if state.session == session {
|
|
||||||
state.sessionType = sessionType
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// handlePty is not supported on JS/WASM
|
// handlePtyLogin is not supported on JS/WASM
|
||||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
|
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
|
||||||
errorMsg := "PTY sessions are not supported on WASM/JS platform\n"
|
errorMsg := "PTY sessions are not supported on WASM/JS platform\n"
|
||||||
if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil {
|
if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil {
|
||||||
logger.Debugf(errWriteSession, err)
|
logger.Debugf(errWriteSession, err)
|
||||||
|
|||||||
@@ -8,19 +8,18 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// StartTestServer starts the SSH server and returns the address it's listening on.
|
||||||
func StartTestServer(t *testing.T, server *Server) string {
|
func StartTestServer(t *testing.T, server *Server) string {
|
||||||
started := make(chan string, 1)
|
started := make(chan string, 1)
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
// Use port 0 to let the OS assign a free port
|
|
||||||
addrPort := netip.MustParseAddrPort("127.0.0.1:0")
|
addrPort := netip.MustParseAddrPort("127.0.0.1:0")
|
||||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||||
errChan <- err
|
errChan <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the actual listening address from the server
|
|
||||||
actualAddr := server.Addr()
|
actualAddr := server.Addr()
|
||||||
if actualAddr == nil {
|
if actualAddr == nil {
|
||||||
errChan <- fmt.Errorf("server started but no listener address available")
|
errChan <- fmt.Errorf("server started but no listener address available")
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user