mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-02 22:19:54 +00:00
Compare commits
23 Commits
feat/dev_v
...
fix/bundle
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8d0f49ae29 | ||
|
|
702d45b7e5 | ||
|
|
cfc81b8fcc | ||
|
|
90aa0cc36a | ||
|
|
cdd31364b2 | ||
|
|
a5a0bf6ff4 | ||
|
|
bcda5eddbb | ||
|
|
b9a7375f64 | ||
|
|
471e2f98d7 | ||
|
|
97b8c53dff | ||
|
|
702552e9dd | ||
|
|
5a3301b3c7 | ||
|
|
8c50979468 | ||
|
|
1e66db8ddb | ||
|
|
16d1a4d550 | ||
|
|
7ea5e37dd4 | ||
|
|
9d7ef9b255 | ||
|
|
944a258459 | ||
|
|
1f9a829f2c | ||
|
|
14af179556 | ||
|
|
1fbb5e6d5d | ||
|
|
6771e35d57 | ||
|
|
e89b1e0596 |
@@ -19,6 +19,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/server"
|
"github.com/netbirdio/netbird/client/server"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/upload-server/types"
|
"github.com/netbirdio/netbird/upload-server/types"
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
const errCloseConnection = "Failed to close connection: %v"
|
const errCloseConnection = "Failed to close connection: %v"
|
||||||
@@ -100,6 +101,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
|||||||
Anonymize: anonymizeFlag,
|
Anonymize: anonymizeFlag,
|
||||||
SystemInfo: systemInfoFlag,
|
SystemInfo: systemInfoFlag,
|
||||||
LogFileCount: logFileCount,
|
LogFileCount: logFileCount,
|
||||||
|
CliVersion: version.NetbirdVersion(),
|
||||||
}
|
}
|
||||||
if uploadBundleFlag {
|
if uploadBundleFlag {
|
||||||
request.UploadURL = uploadBundleURLFlag
|
request.UploadURL = uploadBundleURLFlag
|
||||||
@@ -298,6 +300,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
Anonymize: anonymizeFlag,
|
Anonymize: anonymizeFlag,
|
||||||
SystemInfo: systemInfoFlag,
|
SystemInfo: systemInfoFlag,
|
||||||
LogFileCount: logFileCount,
|
LogFileCount: logFileCount,
|
||||||
|
CliVersion: version.NetbirdVersion(),
|
||||||
}
|
}
|
||||||
if uploadBundleFlag {
|
if uploadBundleFlag {
|
||||||
request.UploadURL = uploadBundleURLFlag
|
request.UploadURL = uploadBundleURLFlag
|
||||||
@@ -432,6 +435,7 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
|
|||||||
SyncResponse: syncResponse,
|
SyncResponse: syncResponse,
|
||||||
LogPath: logFilePath,
|
LogPath: logFilePath,
|
||||||
CPUProfile: nil,
|
CPUProfile: nil,
|
||||||
|
DaemonVersion: version.NetbirdVersion(), // acting as daemon
|
||||||
},
|
},
|
||||||
debug.BundleConfig{
|
debug.BundleConfig{
|
||||||
IncludeSystemInfo: true,
|
IncludeSystemInfo: true,
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ func (p *program) Stop(srv service.Service) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Common setup for service control commands
|
// Common setup for service control commands
|
||||||
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
|
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc, consoleLog bool) (service.Service, error) {
|
||||||
// rootCmd env vars are already applied by PersistentPreRunE.
|
// rootCmd env vars are already applied by PersistentPreRunE.
|
||||||
SetFlagsFromEnvVars(serviceCmd)
|
SetFlagsFromEnvVars(serviceCmd)
|
||||||
|
|
||||||
@@ -112,8 +112,14 @@ func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := util.InitLog(logLevel, logFiles...); err != nil {
|
if consoleLog {
|
||||||
return nil, fmt.Errorf("init log: %w", err)
|
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
|
||||||
|
return nil, fmt.Errorf("init log: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := util.InitLog(logLevel, logFiles...); err != nil {
|
||||||
|
return nil, fmt.Errorf("init log: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg, err := newSVCConfig()
|
cfg, err := newSVCConfig()
|
||||||
@@ -138,7 +144,7 @@ var runCmd = &cobra.Command{
|
|||||||
SetupCloseHandler(ctx, cancel)
|
SetupCloseHandler(ctx, cancel)
|
||||||
SetupDebugHandler(ctx, nil, nil, nil, util.FindFirstLogPath(logFiles))
|
SetupDebugHandler(ctx, nil, nil, nil, util.FindFirstLogPath(logFiles))
|
||||||
|
|
||||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -152,7 +158,7 @@ var startCmd = &cobra.Command{
|
|||||||
Short: "starts NetBird service",
|
Short: "starts NetBird service",
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -170,7 +176,7 @@ var stopCmd = &cobra.Command{
|
|||||||
Short: "stops NetBird service",
|
Short: "stops NetBird service",
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -188,7 +194,7 @@ var restartCmd = &cobra.Command{
|
|||||||
Short: "restarts NetBird service",
|
Short: "restarts NetBird service",
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -206,7 +212,7 @@ var svcStatusCmd = &cobra.Command{
|
|||||||
Short: "shows NetBird service status",
|
Short: "shows NetBird service status",
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
s, err := setupServiceControlCommand(cmd, ctx, cancel, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||||
|
|
||||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
|
|
||||||
@@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
iv, _ := validator.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
@@ -100,6 +101,26 @@ type Options struct {
|
|||||||
MTU *uint16
|
MTU *uint16
|
||||||
// DNSLabels defines additional DNS labels configured in the peer.
|
// DNSLabels defines additional DNS labels configured in the peer.
|
||||||
DNSLabels []string
|
DNSLabels []string
|
||||||
|
// Performance configures the tunnel's buffer pool cap and batch size.
|
||||||
|
Performance Performance
|
||||||
|
}
|
||||||
|
|
||||||
|
// Performance configures the embedded client's tunnel memory/throughput knobs.
|
||||||
|
//
|
||||||
|
// These settings are process-global: any non-nil field also becomes the
|
||||||
|
// default for Clients constructed by later embed.New calls in the same
|
||||||
|
// process. Nil fields are ignored.
|
||||||
|
type Performance struct {
|
||||||
|
// PreallocatedBuffersPerPool caps the per-tunnel buffer pool. Zero
|
||||||
|
// leaves the pool unbounded. Lower values trade throughput for a
|
||||||
|
// tighter memory ceiling. May also be changed on a running Client via
|
||||||
|
// Client.SetPerformance, provided this field was nonzero at construction.
|
||||||
|
PreallocatedBuffersPerPool *uint32
|
||||||
|
// MaxBatchSize overrides the number of packets the tunnel reads or
|
||||||
|
// writes per syscall, which also bounds eager buffer allocation per
|
||||||
|
// worker. Zero uses the platform default. Applied at construction
|
||||||
|
// only; ignored by Client.SetPerformance.
|
||||||
|
MaxBatchSize *uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateCredentials checks that exactly one credential type is provided
|
// validateCredentials checks that exactly one credential type is provided
|
||||||
@@ -199,6 +220,13 @@ func New(opts Options) (*Client, error) {
|
|||||||
config.PrivateKey = opts.PrivateKey
|
config.PrivateKey = opts.PrivateKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if opts.Performance.PreallocatedBuffersPerPool != nil {
|
||||||
|
wgdevice.SetPreallocatedBuffersPerPool(*opts.Performance.PreallocatedBuffersPerPool)
|
||||||
|
}
|
||||||
|
if opts.Performance.MaxBatchSize != nil {
|
||||||
|
wgdevice.SetMaxBatchSizeOverride(*opts.Performance.MaxBatchSize)
|
||||||
|
}
|
||||||
|
|
||||||
return &Client{
|
return &Client{
|
||||||
deviceName: opts.DeviceName,
|
deviceName: opts.DeviceName,
|
||||||
setupKey: opts.SetupKey,
|
setupKey: opts.SetupKey,
|
||||||
@@ -495,6 +523,25 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
|
|||||||
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetPerformance retunes a running Client. Only PreallocatedBuffersPerPool
|
||||||
|
// takes effect, and only when it was nonzero at construction;
|
||||||
|
// MaxBatchSize is construction-only and returns an error if set here.
|
||||||
|
//
|
||||||
|
// Returns ErrClientNotStarted / ErrEngineNotStarted if the Client is not
|
||||||
|
// running yet.
|
||||||
|
func (c *Client) SetPerformance(t Performance) error {
|
||||||
|
if t.MaxBatchSize != nil {
|
||||||
|
return errors.New("MaxBatchSize is construction-only and cannot be changed at runtime")
|
||||||
|
}
|
||||||
|
engine, err := c.getEngine()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return engine.SetPerformance(internal.Performance{
|
||||||
|
PreallocatedBuffersPerPool: t.PreallocatedBuffersPerPool,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// StartCapture begins capturing packets on this client's tunnel device.
|
// StartCapture begins capturing packets on this client's tunnel device.
|
||||||
// Only one capture can be active at a time; starting a new one stops the previous.
|
// Only one capture can be active at a time; starting a new one stops the previous.
|
||||||
// Call StopCapture (or CaptureSession.Stop) to end it.
|
// Call StopCapture (or CaptureSession.Stop) to end it.
|
||||||
|
|||||||
@@ -254,6 +254,8 @@ type BundleGenerator struct {
|
|||||||
capturePath string
|
capturePath string
|
||||||
refreshStatus func() // Optional callback to refresh status before bundle generation
|
refreshStatus func() // Optional callback to refresh status before bundle generation
|
||||||
clientMetrics MetricsExporter
|
clientMetrics MetricsExporter
|
||||||
|
daemonVersion string
|
||||||
|
cliVersion string
|
||||||
|
|
||||||
anonymize bool
|
anonymize bool
|
||||||
includeSystemInfo bool
|
includeSystemInfo bool
|
||||||
@@ -278,6 +280,8 @@ type GeneratorDependencies struct {
|
|||||||
CapturePath string
|
CapturePath string
|
||||||
RefreshStatus func()
|
RefreshStatus func()
|
||||||
ClientMetrics MetricsExporter
|
ClientMetrics MetricsExporter
|
||||||
|
DaemonVersion string
|
||||||
|
CliVersion string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
|
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
|
||||||
@@ -299,6 +303,8 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
|||||||
capturePath: deps.CapturePath,
|
capturePath: deps.CapturePath,
|
||||||
refreshStatus: deps.RefreshStatus,
|
refreshStatus: deps.RefreshStatus,
|
||||||
clientMetrics: deps.ClientMetrics,
|
clientMetrics: deps.ClientMetrics,
|
||||||
|
daemonVersion: deps.DaemonVersion,
|
||||||
|
cliVersion: deps.CliVersion,
|
||||||
|
|
||||||
anonymize: cfg.Anonymize,
|
anonymize: cfg.Anonymize,
|
||||||
includeSystemInfo: cfg.IncludeSystemInfo,
|
includeSystemInfo: cfg.IncludeSystemInfo,
|
||||||
@@ -459,9 +465,11 @@ func (g *BundleGenerator) addStatus() error {
|
|||||||
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
|
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
|
||||||
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
|
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
|
||||||
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, nbstatus.ConvertOptions{
|
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, nbstatus.ConvertOptions{
|
||||||
Anonymize: g.anonymize,
|
Anonymize: g.anonymize,
|
||||||
ProfileName: profName,
|
ProfileName: profName,
|
||||||
|
DaemonVersion: g.daemonVersion,
|
||||||
})
|
})
|
||||||
|
overview.CliVersion = g.cliVersion
|
||||||
statusOutput := overview.FullDetailSummary()
|
statusOutput := overview.FullDetailSummary()
|
||||||
|
|
||||||
statusReader := strings.NewReader(statusOutput)
|
statusReader := strings.NewReader(statusOutput)
|
||||||
@@ -1039,7 +1047,8 @@ func (g *BundleGenerator) addRotatedLogFiles(logDir string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
pattern := filepath.Join(logDir, "client-*.log.gz")
|
// This regex will match both logs rotated by us and logrotate on linux
|
||||||
|
pattern := filepath.Join(logDir, "client*.log.*")
|
||||||
files, err := filepath.Glob(pattern)
|
files, err := filepath.Glob(pattern)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to glob rotated logs: %v", err)
|
log.Warnf("failed to glob rotated logs: %v", err)
|
||||||
@@ -1072,7 +1081,12 @@ func (g *BundleGenerator) addRotatedLogFiles(logDir string) {
|
|||||||
|
|
||||||
for i := 0; i < maxFiles; i++ {
|
for i := 0; i < maxFiles; i++ {
|
||||||
name := filepath.Base(files[i])
|
name := filepath.Base(files[i])
|
||||||
if err := g.addSingleLogFileGz(files[i], name); err != nil {
|
if strings.HasSuffix(name, ".gz") {
|
||||||
|
err = g.addSingleLogFileGz(files[i], name)
|
||||||
|
} else {
|
||||||
|
err = g.addSingleLogfile(files[i], name)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
log.Warnf("failed to add rotated log %s: %v", name, err)
|
log.Warnf("failed to add rotated log %s: %v", name, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
103
client/internal/debug/debug_logfiles_test.go
Normal file
103
client/internal/debug/debug_logfiles_test.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package debug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"archive/zip"
|
||||||
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestAddRotatedLogFiles_PicksUpAllVariants asserts that the rotated-log
|
||||||
|
// glob picks up logs rotated by timberjack (gzipped) and by logrotate (plain
|
||||||
|
// and gzipped), and skips unrelated files.
|
||||||
|
func TestAddRotatedLogFiles_PicksUpAllVariants(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
|
||||||
|
writeFile(t, filepath.Join(dir, "client.log"), "active log\n")
|
||||||
|
writeFile(t, filepath.Join(dir, "other.log"), "unrelated\n")
|
||||||
|
|
||||||
|
timberjackRotated := "client-2026-05-21T10-30-45.000.log.gz"
|
||||||
|
writeGzFile(t, filepath.Join(dir, timberjackRotated), "timberjack rotated content\n")
|
||||||
|
|
||||||
|
logrotatePlain := "client.log.1"
|
||||||
|
writeFile(t, filepath.Join(dir, logrotatePlain), "logrotate plain content\n")
|
||||||
|
|
||||||
|
logrotateGz := "client.log.2.gz"
|
||||||
|
writeGzFile(t, filepath.Join(dir, logrotateGz), "logrotate gz content\n")
|
||||||
|
|
||||||
|
names := runAddRotatedLogFiles(t, dir, 10)
|
||||||
|
|
||||||
|
require.Contains(t, names, timberjackRotated, "timberjack rotated file should be in bundle")
|
||||||
|
require.Contains(t, names, logrotatePlain, "logrotate plain rotated file should be in bundle")
|
||||||
|
require.Contains(t, names, logrotateGz, "logrotate gzipped rotated file should be in bundle")
|
||||||
|
require.NotContains(t, names, "client.log", "active log should not be added by addRotatedLogFiles")
|
||||||
|
require.NotContains(t, names, "other.log", "unrelated files should not be in bundle")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAddRotatedLogFiles_RespectsLogFileCount asserts that only the newest
|
||||||
|
// logFileCount rotated files are bundled, ordered by mtime.
|
||||||
|
func TestAddRotatedLogFiles_RespectsLogFileCount(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
|
||||||
|
oldest := filepath.Join(dir, "client.log.3")
|
||||||
|
middle := filepath.Join(dir, "client.log.2")
|
||||||
|
newest := filepath.Join(dir, "client.log.1")
|
||||||
|
writeFile(t, oldest, "old\n")
|
||||||
|
writeFile(t, middle, "mid\n")
|
||||||
|
writeFile(t, newest, "new\n")
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
require.NoError(t, os.Chtimes(oldest, now.Add(-2*time.Hour), now.Add(-2*time.Hour)))
|
||||||
|
require.NoError(t, os.Chtimes(middle, now.Add(-1*time.Hour), now.Add(-1*time.Hour)))
|
||||||
|
require.NoError(t, os.Chtimes(newest, now, now))
|
||||||
|
|
||||||
|
names := runAddRotatedLogFiles(t, dir, 2)
|
||||||
|
|
||||||
|
require.Contains(t, names, "client.log.1")
|
||||||
|
require.Contains(t, names, "client.log.2")
|
||||||
|
require.NotContains(t, names, "client.log.3", "oldest file should be dropped when logFileCount=2")
|
||||||
|
}
|
||||||
|
|
||||||
|
// runAddRotatedLogFiles calls addRotatedLogFiles against a fresh in-memory
|
||||||
|
// zip writer and returns the set of entry names that ended up in the archive.
|
||||||
|
func runAddRotatedLogFiles(t *testing.T, dir string, logFileCount uint32) map[string]struct{} {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
g := &BundleGenerator{
|
||||||
|
archive: zip.NewWriter(&buf),
|
||||||
|
logFileCount: logFileCount,
|
||||||
|
}
|
||||||
|
g.addRotatedLogFiles(dir)
|
||||||
|
require.NoError(t, g.archive.Close())
|
||||||
|
|
||||||
|
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
names := make(map[string]struct{}, len(zr.File))
|
||||||
|
for _, f := range zr.File {
|
||||||
|
names[f.Name] = struct{}{}
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeFile(t *testing.T, path, content string) {
|
||||||
|
t.Helper()
|
||||||
|
require.NoError(t, os.WriteFile(path, []byte(content), 0o644))
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeGzFile(t *testing.T, path, content string) {
|
||||||
|
t.Helper()
|
||||||
|
var buf bytes.Buffer
|
||||||
|
gw := gzip.NewWriter(&buf)
|
||||||
|
_, err := io.WriteString(gw, content)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, gw.Close())
|
||||||
|
require.NoError(t, os.WriteFile(path, buf.Bytes(), 0o644))
|
||||||
|
}
|
||||||
@@ -72,6 +72,7 @@ import (
|
|||||||
sProto "github.com/netbirdio/netbird/shared/signal/proto"
|
sProto "github.com/netbirdio/netbird/shared/signal/proto"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
"github.com/netbirdio/netbird/util/capture"
|
"github.com/netbirdio/netbird/util/capture"
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
||||||
@@ -1141,6 +1142,7 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
|
|||||||
LogPath: e.config.LogPath,
|
LogPath: e.config.LogPath,
|
||||||
TempDir: e.config.TempDir,
|
TempDir: e.config.TempDir,
|
||||||
ClientMetrics: e.clientMetrics,
|
ClientMetrics: e.clientMetrics,
|
||||||
|
DaemonVersion: version.NetbirdVersion(),
|
||||||
RefreshStatus: func() {
|
RefreshStatus: func() {
|
||||||
e.RunHealthProbes(true)
|
e.RunHealthProbes(true)
|
||||||
},
|
},
|
||||||
@@ -1967,6 +1969,29 @@ func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
|
|||||||
return e.clientMetrics
|
return e.clientMetrics
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Performance bundles runtime-adjustable tunnel pool knobs.
|
||||||
|
// See Engine.SetPerformance. Nil fields are ignored.
|
||||||
|
type Performance struct {
|
||||||
|
PreallocatedBuffersPerPool *uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPerformance applies the given tuning to this engine's live Device.
|
||||||
|
func (e *Engine) SetPerformance(t Performance) error {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
defer e.syncMsgMux.Unlock()
|
||||||
|
if e.wgInterface == nil {
|
||||||
|
return fmt.Errorf("wg interface not initialized")
|
||||||
|
}
|
||||||
|
dev := e.wgInterface.GetWGDevice()
|
||||||
|
if dev == nil {
|
||||||
|
return fmt.Errorf("wg device not initialized")
|
||||||
|
}
|
||||||
|
if t.PreallocatedBuffersPerPool != nil {
|
||||||
|
dev.SetPreallocatedBuffersPerPool(*t.PreallocatedBuffersPerPool)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||||
iface, err := net.InterfaceByName(ifaceName)
|
iface, err := net.InterfaceByName(ifaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/management/server/job"
|
"github.com/netbirdio/netbird/management/server/job"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
@@ -66,8 +66,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
|
||||||
"github.com/netbirdio/netbird/shared/netiputil"
|
"github.com/netbirdio/netbird/shared/netiputil"
|
||||||
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||||
@@ -1641,7 +1641,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer/id"
|
"github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/worker"
|
"github.com/netbirdio/netbird/client/internal/peer/worker"
|
||||||
"github.com/netbirdio/netbird/client/internal/portforward"
|
"github.com/netbirdio/netbird/client/internal/portforward"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
@@ -899,7 +900,7 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to deterministic key if no NetBird PSK is configured
|
// Fallback to deterministic key if no NetBird PSK is configured
|
||||||
determKey, err := conn.rosenpassDetermKey()
|
determKey, err := rosenpass.DeterministicSeedKey(conn.config.LocalKey, conn.config.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
|
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
|
||||||
return nil
|
return nil
|
||||||
@@ -908,26 +909,6 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
|||||||
return determKey
|
return determKey
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo: move this logic into Rosenpass package
|
|
||||||
func (conn *Conn) rosenpassDetermKey() (*wgtypes.Key, error) {
|
|
||||||
lk := []byte(conn.config.LocalKey)
|
|
||||||
rk := []byte(conn.config.Key) // remote key
|
|
||||||
var keyInput []byte
|
|
||||||
if string(lk) > string(rk) {
|
|
||||||
//nolint:gocritic
|
|
||||||
keyInput = append(lk[:16], rk[:16]...)
|
|
||||||
} else {
|
|
||||||
//nolint:gocritic
|
|
||||||
keyInput = append(rk[:16], lk[:16]...)
|
|
||||||
}
|
|
||||||
|
|
||||||
key, err := wgtypes.NewKey(keyInput)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isController(config ConnConfig) bool {
|
func isController(config ConnConfig) bool {
|
||||||
return config.LocalKey > config.Key
|
return config.LocalKey > config.Key
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,6 +28,15 @@ func hashRosenpassKey(key []byte) string {
|
|||||||
return hex.EncodeToString(hasher.Sum(nil))
|
return hex.EncodeToString(hasher.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// rpServer is the subset of rp.Server used by Manager. Defined as an interface
|
||||||
|
// so tests can substitute a mock without spinning up a real UDP server.
|
||||||
|
type rpServer interface {
|
||||||
|
AddPeer(rp.PeerConfig) (rp.PeerID, error)
|
||||||
|
RemovePeer(rp.PeerID) error
|
||||||
|
Run() error
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
ifaceName string
|
ifaceName string
|
||||||
spk []byte
|
spk []byte
|
||||||
@@ -36,7 +45,7 @@ type Manager struct {
|
|||||||
preSharedKey *[32]byte
|
preSharedKey *[32]byte
|
||||||
rpPeerIDs map[string]*rp.PeerID
|
rpPeerIDs map[string]*rp.PeerID
|
||||||
rpWgHandler *NetbirdHandler
|
rpWgHandler *NetbirdHandler
|
||||||
server *rp.Server
|
server rpServer
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
port int
|
port int
|
||||||
wgIface PresharedKeySetter
|
wgIface PresharedKeySetter
|
||||||
@@ -51,7 +60,22 @@ func NewManager(preSharedKey *wgtypes.Key, wgIfaceName string) (*Manager, error)
|
|||||||
|
|
||||||
rpKeyHash := hashRosenpassKey(public)
|
rpKeyHash := hashRosenpassKey(public)
|
||||||
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
|
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
|
||||||
return &Manager{ifaceName: wgIfaceName, rpKeyHash: rpKeyHash, spk: public, ssk: secret, preSharedKey: (*[32]byte)(preSharedKey), rpPeerIDs: make(map[string]*rp.PeerID), lock: sync.Mutex{}}, nil
|
return &Manager{
|
||||||
|
ifaceName: wgIfaceName,
|
||||||
|
rpKeyHash: rpKeyHash,
|
||||||
|
spk: public,
|
||||||
|
ssk: secret,
|
||||||
|
preSharedKey: (*[32]byte)(preSharedKey),
|
||||||
|
rpPeerIDs: make(map[string]*rp.PeerID),
|
||||||
|
// rpWgHandler is created here (instead of only in generateConfig) so it
|
||||||
|
// is never nil between NewManager and Run(). Otherwise an early
|
||||||
|
// OnConnected call (race observed on Android, issue #4341) panics on
|
||||||
|
// nil receiver in addPeer -> m.rpWgHandler.AddPeer. generateConfig will
|
||||||
|
// replace it with a fresh handler on each Run() to clear stale peer
|
||||||
|
// state from previous engine sessions.
|
||||||
|
rpWgHandler: NewNetbirdHandler(),
|
||||||
|
lock: sync.Mutex{},
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) GetPubKey() []byte {
|
func (m *Manager) GetPubKey() []byte {
|
||||||
@@ -65,6 +89,16 @@ func (m *Manager) GetAddress() *net.UDPAddr {
|
|||||||
|
|
||||||
// addPeer adds a new peer to the Rosenpass server
|
// addPeer adds a new peer to the Rosenpass server
|
||||||
func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuardIP string, wireGuardPubKey string) error {
|
func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuardIP string, wireGuardPubKey string) error {
|
||||||
|
// Defense in depth against issue #4341 (Android crash): if Run() has not
|
||||||
|
// completed yet, m.server / m.rpWgHandler may be nil. Return an explicit
|
||||||
|
// error instead of panicking on nil-receiver dereference.
|
||||||
|
if m.server == nil {
|
||||||
|
return fmt.Errorf("rosenpass server not initialized")
|
||||||
|
}
|
||||||
|
if m.rpWgHandler == nil {
|
||||||
|
return fmt.Errorf("rosenpass wg handler not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
pcfg := rp.PeerConfig{PublicKey: rosenpassPubKey}
|
pcfg := rp.PeerConfig{PublicKey: rosenpassPubKey}
|
||||||
if m.preSharedKey != nil {
|
if m.preSharedKey != nil {
|
||||||
@@ -79,6 +113,16 @@ func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuar
|
|||||||
if pcfg.Endpoint, err = net.ResolveUDPAddr("udp", peerAddr); err != nil {
|
if pcfg.Endpoint, err = net.ResolveUDPAddr("udp", peerAddr); err != nil {
|
||||||
return fmt.Errorf("failed to resolve peer endpoint address: %w", err)
|
return fmt.Errorf("failed to resolve peer endpoint address: %w", err)
|
||||||
}
|
}
|
||||||
|
// Our local Rosenpass UDP server binds on the IPv6 wildcard ([::]) — see
|
||||||
|
// GetAddress(). The remote peer's endpoint (pcfg.Endpoint) is the destination
|
||||||
|
// our server will sendto when initiating handshakes. ResolveUDPAddr returns a
|
||||||
|
// 4-byte IPv4 for IPv4 hosts, which the kernel rejects (EDESTADDRREQ) when
|
||||||
|
// sent from an AF_INET6 socket. Normalize the remote endpoint to IPv4-mapped
|
||||||
|
// IPv6 so its address family matches our listening socket.
|
||||||
|
// TODO: maybe bind the Rosenpass UDP server to the peer wg IP addr
|
||||||
|
if v4 := pcfg.Endpoint.IP.To4(); v4 != nil {
|
||||||
|
pcfg.Endpoint.IP = v4.To16()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
peerID, err := m.server.AddPeer(pcfg)
|
peerID, err := m.server.AddPeer(pcfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -182,24 +226,31 @@ func (m *Manager) Run() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.server, err = rp.NewUDPServer(conf)
|
server, err := rp.NewUDPServer(conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.lock.Lock()
|
||||||
|
m.server = server
|
||||||
|
m.lock.Unlock()
|
||||||
|
|
||||||
log.Infof("starting rosenpass server on port %d", m.port)
|
log.Infof("starting rosenpass server on port %d", m.port)
|
||||||
|
|
||||||
return m.server.Run()
|
return server.Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the Rosenpass server
|
// Close closes the Rosenpass server
|
||||||
func (m *Manager) Close() error {
|
func (m *Manager) Close() error {
|
||||||
if m.server != nil {
|
m.lock.Lock()
|
||||||
err := m.server.Close()
|
server := m.server
|
||||||
if err != nil {
|
m.server = nil
|
||||||
log.Errorf("failed closing local rosenpass server")
|
m.lock.Unlock()
|
||||||
}
|
if server == nil {
|
||||||
m.server = nil
|
return nil
|
||||||
|
}
|
||||||
|
if err := server.Close(); err != nil {
|
||||||
|
log.Errorf("failed closing local rosenpass server: %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,412 @@
|
|||||||
package rosenpass
|
package rosenpass
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
rp "cunicu.li/go-rosenpass"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// --- test doubles -----------------------------------------------------------
|
||||||
|
|
||||||
|
type addPeerCall struct {
|
||||||
|
cfg rp.PeerConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
type removePeerCall struct {
|
||||||
|
id rp.PeerID
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockServer struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
addCalls []addPeerCall
|
||||||
|
removed []removePeerCall
|
||||||
|
nextID rp.PeerID
|
||||||
|
addErr error
|
||||||
|
removeErr error
|
||||||
|
closed bool
|
||||||
|
ran bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockServer) AddPeer(cfg rp.PeerConfig) (rp.PeerID, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.addCalls = append(m.addCalls, addPeerCall{cfg: cfg})
|
||||||
|
if m.addErr != nil {
|
||||||
|
return rp.PeerID{}, m.addErr
|
||||||
|
}
|
||||||
|
// Increment a byte in nextID so distinct peers get distinct IDs.
|
||||||
|
m.nextID[0]++
|
||||||
|
return m.nextID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockServer) RemovePeer(id rp.PeerID) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.removed = append(m.removed, removePeerCall{id: id})
|
||||||
|
return m.removeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockServer) Run() error { m.ran = true; return nil }
|
||||||
|
func (m *mockServer) Close() error { m.closed = true; return nil }
|
||||||
|
|
||||||
|
type setPSKCall struct {
|
||||||
|
peerKey string
|
||||||
|
psk wgtypes.Key
|
||||||
|
updateOnly bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockIface struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
calls []setPSKCall
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.calls = append(m.calls, setPSKCall{peerKey: peerKey, psk: psk, updateOnly: updateOnly})
|
||||||
|
return m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTestManager builds a Manager with deterministic spk so tie-break
|
||||||
|
// against a peer pubkey is controllable from tests. The provided spk byte
|
||||||
|
// becomes the first byte; remaining bytes are zero.
|
||||||
|
func newTestManager(spkFirstByte byte, mock *mockServer) *Manager {
|
||||||
|
spk := make([]byte, 32)
|
||||||
|
spk[0] = spkFirstByte
|
||||||
|
return &Manager{
|
||||||
|
ifaceName: "wt0",
|
||||||
|
spk: spk,
|
||||||
|
ssk: make([]byte, 32),
|
||||||
|
rpKeyHash: "test-hash",
|
||||||
|
rpPeerIDs: make(map[string]*rp.PeerID),
|
||||||
|
rpWgHandler: NewNetbirdHandler(),
|
||||||
|
server: mock,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// validWGKey returns a deterministic 32-byte wireguard public key (base64).
|
||||||
|
func validWGKey(t *testing.T, lastByte byte) string {
|
||||||
|
t.Helper()
|
||||||
|
var k wgtypes.Key
|
||||||
|
k[31] = lastByte
|
||||||
|
return k.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- pure helpers ----------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHashRosenpassKey_Deterministic(t *testing.T) {
|
||||||
|
key := []byte("hello-rosenpass")
|
||||||
|
require.Equal(t, hashRosenpassKey(key), hashRosenpassKey(key))
|
||||||
|
require.Len(t, hashRosenpassKey(key), 64) // sha256 hex
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashRosenpassKey_DifferentInputsDifferOutputs(t *testing.T) {
|
||||||
|
require.NotEqual(t, hashRosenpassKey([]byte("a")), hashRosenpassKey([]byte("b")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetLogLevel_DefaultWhenUnset(t *testing.T) {
|
||||||
|
// Snapshot + unset to exercise the LookupEnv ok=false branch. t.Setenv
|
||||||
|
// can only set, not delete, so do it manually with restore via t.Cleanup.
|
||||||
|
prev, hadPrev := os.LookupEnv(defaultLogLevelVar)
|
||||||
|
require.NoError(t, os.Unsetenv(defaultLogLevelVar))
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if hadPrev {
|
||||||
|
_ = os.Setenv(defaultLogLevelVar, prev)
|
||||||
|
} else {
|
||||||
|
_ = os.Unsetenv(defaultLogLevelVar)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
require.Equal(t, defaultLog.String(), getLogLevel().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetLogLevel_Cases(t *testing.T) {
|
||||||
|
cases := map[string]string{
|
||||||
|
"debug": "DEBUG",
|
||||||
|
"info": "INFO",
|
||||||
|
"warn": "WARN",
|
||||||
|
"error": "ERROR",
|
||||||
|
"unknown": "INFO", // default fallback
|
||||||
|
}
|
||||||
|
for input, wantStr := range cases {
|
||||||
|
input, wantStr := input, wantStr
|
||||||
|
t.Run(input, func(t *testing.T) {
|
||||||
|
t.Setenv(defaultLogLevelVar, input)
|
||||||
|
require.Equal(t, wantStr, getLogLevel().String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFindRandomAvailableUDPPort(t *testing.T) {
|
func TestFindRandomAvailableUDPPort(t *testing.T) {
|
||||||
port, err := findRandomAvailableUDPPort()
|
port, err := findRandomAvailableUDPPort()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Greater(t, port, 0)
|
require.Greater(t, port, 0)
|
||||||
require.LessOrEqual(t, port, 65535)
|
require.LessOrEqual(t, port, 65535)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- addPeer ---------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestAddPeer_HigherLocalPubkey_SetsEndpoint(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv) // local spk lexicographically larger
|
||||||
|
|
||||||
|
remotePubKey := make([]byte, 32) // remote spk = all zeros (smaller)
|
||||||
|
err := m.addPeer(remotePubKey, "rosenpass-host:7000", "100.1.1.1", validWGKey(t, 1))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, srv.addCalls, 1)
|
||||||
|
|
||||||
|
ep := srv.addCalls[0].cfg.Endpoint
|
||||||
|
require.NotNil(t, ep, "initiator side must set Endpoint")
|
||||||
|
require.Equal(t, 7000, ep.Port)
|
||||||
|
require.Equal(t, "100.1.1.1", ep.IP.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_HigherLocalPubkey_EndpointIPIsIPv4Mapped(t *testing.T) {
|
||||||
|
// Regression guard for the EDESTADDRREQ fix: Endpoint.IP must be 16-byte
|
||||||
|
// (IPv4-mapped IPv6) so it matches the AF_INET6 listening socket family.
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ep := srv.addCalls[0].cfg.Endpoint
|
||||||
|
require.NotNil(t, ep)
|
||||||
|
require.Len(t, ep.IP, 16, "IPv4 endpoint must be normalized to 16-byte v4-mapped form")
|
||||||
|
require.True(t, ep.IP.To4() != nil, "Endpoint must still be detected as IPv4")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_LowerLocalPubkey_LeavesEndpointNil(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0x00, srv) // local spk smaller
|
||||||
|
|
||||||
|
remotePubKey := make([]byte, 32)
|
||||||
|
remotePubKey[0] = 0xFF
|
||||||
|
err := m.addPeer(remotePubKey, "rp:5000", "100.1.1.1", validWGKey(t, 2))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Nil(t, srv.addCalls[0].cfg.Endpoint, "responder side must NOT set Endpoint")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_PresharedKeyPropagated(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
psk := &wgtypes.Key{0x42}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
m.preSharedKey = (*[32]byte)(psk)
|
||||||
|
|
||||||
|
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 3))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, [32]byte(*psk), [32]byte(srv.addCalls[0].cfg.PresharedKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_InvalidRosenpassAddr_ReturnsError(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv) // initiator path → parses rosenpassAddr
|
||||||
|
|
||||||
|
err := m.addPeer(make([]byte, 32), "not-a-host-port", "100.1.1.1", validWGKey(t, 1))
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Empty(t, srv.addCalls, "server.AddPeer must not run when address parse fails")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_InvalidWireGuardPubKey_ReturnsError(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", "not-a-valid-key")
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_ServerError_Propagates(t *testing.T) {
|
||||||
|
srv := &mockServer{addErr: errors.New("boom")}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regression guard for issue #4341 (Android crash). If Run() has not completed
|
||||||
|
// before OnConnected fires, m.rpWgHandler or m.server may be nil. Without the
|
||||||
|
// nil guards, m.rpWgHandler.AddPeer panics on nil receiver.
|
||||||
|
func TestAddPeer_NilHandler_ReturnsErrorNoCrash(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
m.rpWgHandler = nil // simulate Run() not yet completed
|
||||||
|
|
||||||
|
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "wg handler not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_NilServer_ReturnsErrorNoCrash(t *testing.T) {
|
||||||
|
m := newTestManager(0xFF, nil)
|
||||||
|
m.server = nil // simulate Run() not yet completed
|
||||||
|
|
||||||
|
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "server not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager must pre-initialize rpWgHandler so the nil-receiver crash from
|
||||||
|
// issue #4341 cannot occur in the window between NewManager and Run().
|
||||||
|
func TestNewManager_PreInitializesHandler(t *testing.T) {
|
||||||
|
psk := wgtypes.Key{}
|
||||||
|
m, err := NewManager(&psk, "wt0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, m.rpWgHandler, "rpWgHandler must be initialized in NewManager")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_RecordsPeerID(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
wgKey := validWGKey(t, 5)
|
||||||
|
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Contains(t, m.rpPeerIDs, wgKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- OnConnected / OnDisconnected ------------------------------------------
|
||||||
|
|
||||||
|
func TestOnConnected_NilRemotePubKey_NoAddPeer(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
m.OnConnected(validWGKey(t, 1), nil, "100.1.1.1", "rp:5000")
|
||||||
|
require.Empty(t, srv.addCalls, "nil remote rosenpass pubkey must skip AddPeer")
|
||||||
|
require.Empty(t, m.rpPeerIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOnConnected_ValidPubKey_CallsAddPeer(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
wgKey := validWGKey(t, 1)
|
||||||
|
m.OnConnected(wgKey, make([]byte, 32), "100.1.1.1", "rp:5000")
|
||||||
|
require.Len(t, srv.addCalls, 1)
|
||||||
|
require.Contains(t, m.rpPeerIDs, wgKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOnDisconnected_UnknownPeer_NoOp(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
m.OnDisconnected(validWGKey(t, 99))
|
||||||
|
require.Empty(t, srv.removed, "unknown peer key must not call RemovePeer")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOnDisconnected_KnownPeer_CallsRemoveAndForgets(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
wgKey := validWGKey(t, 1)
|
||||||
|
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
|
||||||
|
require.Contains(t, m.rpPeerIDs, wgKey)
|
||||||
|
|
||||||
|
m.OnDisconnected(wgKey)
|
||||||
|
require.Len(t, srv.removed, 1)
|
||||||
|
require.NotContains(t, m.rpPeerIDs, wgKey, "peer must be forgotten after disconnect")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- IsPresharedKeyInitialized ---------------------------------------------
|
||||||
|
|
||||||
|
func TestIsPresharedKeyInitialized_UnknownPeer_ReturnsFalse(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
require.False(t, m.IsPresharedKeyInitialized(validWGKey(t, 1)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsPresharedKeyInitialized_AddedButNotHandshaken_ReturnsFalse(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
wgKey := validWGKey(t, 2)
|
||||||
|
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
|
||||||
|
require.False(t, m.IsPresharedKeyInitialized(wgKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- NetbirdHandler.outputKey ----------------------------------------------
|
||||||
|
|
||||||
|
func TestHandler_OutputKey_FirstCallUsesUpdateOnlyFalse(t *testing.T) {
|
||||||
|
h := NewNetbirdHandler()
|
||||||
|
iface := &mockIface{}
|
||||||
|
h.SetInterface(iface)
|
||||||
|
|
||||||
|
pid := rp.PeerID{0x01}
|
||||||
|
wgKey := wgtypes.Key{0xAA}
|
||||||
|
h.AddPeer(pid, "wt0", rp.Key(wgKey))
|
||||||
|
|
||||||
|
psk := rp.Key{0xBB}
|
||||||
|
h.HandshakeCompleted(pid, psk)
|
||||||
|
|
||||||
|
require.Len(t, iface.calls, 1)
|
||||||
|
require.False(t, iface.calls[0].updateOnly, "first PSK rotation must use updateOnly=false")
|
||||||
|
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandler_OutputKey_SubsequentCallsUseUpdateOnlyTrue(t *testing.T) {
|
||||||
|
h := NewNetbirdHandler()
|
||||||
|
iface := &mockIface{}
|
||||||
|
h.SetInterface(iface)
|
||||||
|
|
||||||
|
pid := rp.PeerID{0x02}
|
||||||
|
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xCC}))
|
||||||
|
|
||||||
|
h.HandshakeCompleted(pid, rp.Key{0x01}) // first
|
||||||
|
h.HandshakeCompleted(pid, rp.Key{0x02}) // second
|
||||||
|
|
||||||
|
require.Len(t, iface.calls, 2)
|
||||||
|
require.False(t, iface.calls[0].updateOnly)
|
||||||
|
require.True(t, iface.calls[1].updateOnly, "subsequent rotations must use updateOnly=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandler_OutputKey_NilInterface_NoCrashNoCall(t *testing.T) {
|
||||||
|
h := NewNetbirdHandler()
|
||||||
|
// no SetInterface — iface remains nil
|
||||||
|
pid := rp.PeerID{0x03}
|
||||||
|
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{}))
|
||||||
|
|
||||||
|
// Must not panic.
|
||||||
|
h.HandshakeCompleted(pid, rp.Key{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandler_OutputKey_UnknownPeer_NoCall(t *testing.T) {
|
||||||
|
h := NewNetbirdHandler()
|
||||||
|
iface := &mockIface{}
|
||||||
|
h.SetInterface(iface)
|
||||||
|
|
||||||
|
h.HandshakeCompleted(rp.PeerID{0xFF}, rp.Key{})
|
||||||
|
require.Empty(t, iface.calls, "unknown peer id must not trigger SetPresharedKey")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandler_RemovePeer_ClearsInitializedState(t *testing.T) {
|
||||||
|
h := NewNetbirdHandler()
|
||||||
|
iface := &mockIface{}
|
||||||
|
h.SetInterface(iface)
|
||||||
|
|
||||||
|
pid := rp.PeerID{0x04}
|
||||||
|
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xDD}))
|
||||||
|
h.HandshakeCompleted(pid, rp.Key{0x01})
|
||||||
|
require.True(t, h.IsPeerInitialized(pid))
|
||||||
|
|
||||||
|
h.RemovePeer(pid)
|
||||||
|
require.False(t, h.IsPeerInitialized(pid), "RemovePeer must clear initialized flag")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandler_SetInterfaceAfterAddPeer_StillReceivesKey(t *testing.T) {
|
||||||
|
h := NewNetbirdHandler()
|
||||||
|
pid := rp.PeerID{0x05}
|
||||||
|
wgKey := wgtypes.Key{0xEE}
|
||||||
|
h.AddPeer(pid, "wt0", rp.Key(wgKey))
|
||||||
|
|
||||||
|
iface := &mockIface{}
|
||||||
|
h.SetInterface(iface) // set after AddPeer
|
||||||
|
|
||||||
|
h.HandshakeCompleted(pid, rp.Key{0x42})
|
||||||
|
require.Len(t, iface.calls, 1)
|
||||||
|
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
|
||||||
|
}
|
||||||
|
|||||||
42
client/internal/rosenpass/seed.go
Normal file
42
client/internal/rosenpass/seed.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package rosenpass
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DeterministicSeedKey derives a 32-byte WireGuard preshared key from a pair
|
||||||
|
// of peer public keys. Both peers, given the same key pair, produce the same
|
||||||
|
// output regardless of which side runs the function: the inputs are ordered
|
||||||
|
// lexicographically before concatenation.
|
||||||
|
//
|
||||||
|
// NetBird uses this value as the initial Rosenpass-side preshared key when no
|
||||||
|
// explicit account-level PSK is configured, so both peers converge on the same
|
||||||
|
// PSK before the first post-quantum handshake completes.
|
||||||
|
//
|
||||||
|
// The resulting key MUST NOT be treated as quantum-safe: it is deterministic
|
||||||
|
// from public keys and exists only to seed WireGuard until Rosenpass rotates
|
||||||
|
// in a real post-quantum PSK.
|
||||||
|
func DeterministicSeedKey(localKey, remoteKey string) (*wgtypes.Key, error) {
|
||||||
|
lk := []byte(localKey)
|
||||||
|
rk := []byte(remoteKey)
|
||||||
|
if len(lk) < 16 || len(rk) < 16 {
|
||||||
|
return nil, fmt.Errorf("rosenpass: peer keys must be at least 16 bytes (got local=%d, remote=%d)", len(lk), len(rk))
|
||||||
|
}
|
||||||
|
|
||||||
|
var keyInput []byte
|
||||||
|
if localKey > remoteKey {
|
||||||
|
keyInput = append(keyInput, lk[:16]...)
|
||||||
|
keyInput = append(keyInput, rk[:16]...)
|
||||||
|
} else {
|
||||||
|
keyInput = append(keyInput, rk[:16]...)
|
||||||
|
keyInput = append(keyInput, lk[:16]...)
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := wgtypes.NewKey(keyInput)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("rosenpass: deterministic seed key: %w", err)
|
||||||
|
}
|
||||||
|
return &key, nil
|
||||||
|
}
|
||||||
44
client/internal/rosenpass/seed_test.go
Normal file
44
client/internal/rosenpass/seed_test.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package rosenpass
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeterministicSeedKey_SameForBothSides(t *testing.T) {
|
||||||
|
// Peer A and peer B must derive the same PSK regardless of which side
|
||||||
|
// computes it: the function orders inputs internally.
|
||||||
|
a := strings.Repeat("a", 32)
|
||||||
|
b := strings.Repeat("b", 32)
|
||||||
|
|
||||||
|
keyAB, err := DeterministicSeedKey(a, b)
|
||||||
|
require.NoError(t, err)
|
||||||
|
keyBA, err := DeterministicSeedKey(b, a)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, keyAB.String(), keyBA.String(), "swapping arguments must yield identical key")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeterministicSeedKey_ChangesWithKeys(t *testing.T) {
|
||||||
|
a := strings.Repeat("a", 32)
|
||||||
|
b := strings.Repeat("b", 32)
|
||||||
|
c := strings.Repeat("c", 32)
|
||||||
|
|
||||||
|
keyAB, err := DeterministicSeedKey(a, b)
|
||||||
|
require.NoError(t, err)
|
||||||
|
keyAC, err := DeterministicSeedKey(a, c)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEqual(t, keyAB.String(), keyAC.String(), "different peer pair must yield different key")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeterministicSeedKey_TooShortKey_ReturnsError(t *testing.T) {
|
||||||
|
short := "short" // < 16 bytes
|
||||||
|
long := strings.Repeat("x", 32)
|
||||||
|
|
||||||
|
_, err := DeterministicSeedKey(short, long)
|
||||||
|
require.Error(t, err)
|
||||||
|
_, err = DeterministicSeedKey(long, short)
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
@@ -96,17 +96,19 @@ func (m *Manager) Stop(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
cancel := m.cancel
|
||||||
|
done := m.done
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
if m.cancel == nil {
|
if cancel == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
m.cancel()
|
cancel()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case <-m.done:
|
case <-done:
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -2709,6 +2709,7 @@ type DebugBundleRequest struct {
|
|||||||
SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"`
|
SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"`
|
||||||
UploadURL string `protobuf:"bytes,4,opt,name=uploadURL,proto3" json:"uploadURL,omitempty"`
|
UploadURL string `protobuf:"bytes,4,opt,name=uploadURL,proto3" json:"uploadURL,omitempty"`
|
||||||
LogFileCount uint32 `protobuf:"varint,5,opt,name=logFileCount,proto3" json:"logFileCount,omitempty"`
|
LogFileCount uint32 `protobuf:"varint,5,opt,name=logFileCount,proto3" json:"logFileCount,omitempty"`
|
||||||
|
CliVersion string `protobuf:"bytes,6,opt,name=cliVersion,proto3" json:"cliVersion,omitempty"`
|
||||||
unknownFields protoimpl.UnknownFields
|
unknownFields protoimpl.UnknownFields
|
||||||
sizeCache protoimpl.SizeCache
|
sizeCache protoimpl.SizeCache
|
||||||
}
|
}
|
||||||
@@ -2771,6 +2772,13 @@ func (x *DebugBundleRequest) GetLogFileCount() uint32 {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (x *DebugBundleRequest) GetCliVersion() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.CliVersion
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
type DebugBundleResponse struct {
|
type DebugBundleResponse struct {
|
||||||
state protoimpl.MessageState `protogen:"open.v1"`
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"`
|
Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"`
|
||||||
@@ -6475,14 +6483,17 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"\x12translatedHostname\x18\x04 \x01(\tR\x12translatedHostname\x128\n" +
|
"\x12translatedHostname\x18\x04 \x01(\tR\x12translatedHostname\x128\n" +
|
||||||
"\x0etranslatedPort\x18\x05 \x01(\v2\x10.daemon.PortInfoR\x0etranslatedPort\"G\n" +
|
"\x0etranslatedPort\x18\x05 \x01(\v2\x10.daemon.PortInfoR\x0etranslatedPort\"G\n" +
|
||||||
"\x17ForwardingRulesResponse\x12,\n" +
|
"\x17ForwardingRulesResponse\x12,\n" +
|
||||||
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\x94\x01\n" +
|
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\xb4\x01\n" +
|
||||||
"\x12DebugBundleRequest\x12\x1c\n" +
|
"\x12DebugBundleRequest\x12\x1c\n" +
|
||||||
"\tanonymize\x18\x01 \x01(\bR\tanonymize\x12\x1e\n" +
|
"\tanonymize\x18\x01 \x01(\bR\tanonymize\x12\x1e\n" +
|
||||||
"\n" +
|
"\n" +
|
||||||
"systemInfo\x18\x03 \x01(\bR\n" +
|
"systemInfo\x18\x03 \x01(\bR\n" +
|
||||||
"systemInfo\x12\x1c\n" +
|
"systemInfo\x12\x1c\n" +
|
||||||
"\tuploadURL\x18\x04 \x01(\tR\tuploadURL\x12\"\n" +
|
"\tuploadURL\x18\x04 \x01(\tR\tuploadURL\x12\"\n" +
|
||||||
"\flogFileCount\x18\x05 \x01(\rR\flogFileCount\"}\n" +
|
"\flogFileCount\x18\x05 \x01(\rR\flogFileCount\x12\x1e\n" +
|
||||||
|
"\n" +
|
||||||
|
"cliVersion\x18\x06 \x01(\tR\n" +
|
||||||
|
"cliVersion\"}\n" +
|
||||||
"\x13DebugBundleResponse\x12\x12\n" +
|
"\x13DebugBundleResponse\x12\x12\n" +
|
||||||
"\x04path\x18\x01 \x01(\tR\x04path\x12 \n" +
|
"\x04path\x18\x01 \x01(\tR\x04path\x12 \n" +
|
||||||
"\vuploadedKey\x18\x02 \x01(\tR\vuploadedKey\x120\n" +
|
"\vuploadedKey\x18\x02 \x01(\tR\vuploadedKey\x120\n" +
|
||||||
|
|||||||
@@ -471,6 +471,7 @@ message DebugBundleRequest {
|
|||||||
bool systemInfo = 3;
|
bool systemInfo = 3;
|
||||||
string uploadURL = 4;
|
string uploadURL = 4;
|
||||||
uint32 logFileCount = 5;
|
uint32 logFileCount = 5;
|
||||||
|
string cliVersion = 6;
|
||||||
}
|
}
|
||||||
|
|
||||||
message DebugBundleResponse {
|
message DebugBundleResponse {
|
||||||
|
|||||||
@@ -1,17 +1,16 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
if ! which realpath > /dev/null 2>&1
|
if ! which realpath >/dev/null 2>&1; then
|
||||||
then
|
echo realpath is not installed
|
||||||
echo realpath is not installed
|
echo run: brew install coreutils
|
||||||
echo run: brew install coreutils
|
exit 1
|
||||||
exit 1
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
old_pwd=$(pwd)
|
old_pwd=$(pwd)
|
||||||
script_path=$(dirname $(realpath "$0"))
|
script_path=$(dirname "$(realpath "$0")")
|
||||||
cd "$script_path"
|
cd "$script_path"
|
||||||
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6
|
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6
|
||||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
|
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.6.1
|
||||||
protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../ --experimental_allow_proto3_optional
|
protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../ --experimental_allow_proto3_optional
|
||||||
cd "$old_pwd"
|
cd "$old_pwd"
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/debug"
|
"github.com/netbirdio/netbird/client/internal/debug"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DebugBundle creates a debug bundle and returns the location.
|
// DebugBundle creates a debug bundle and returns the location.
|
||||||
@@ -67,6 +68,8 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
|||||||
CapturePath: capturePath,
|
CapturePath: capturePath,
|
||||||
RefreshStatus: refreshStatus,
|
RefreshStatus: refreshStatus,
|
||||||
ClientMetrics: clientMetrics,
|
ClientMetrics: clientMetrics,
|
||||||
|
DaemonVersion: version.NetbirdVersion(),
|
||||||
|
CliVersion: req.CliVersion,
|
||||||
},
|
},
|
||||||
debug.BundleConfig{
|
debug.BundleConfig{
|
||||||
Anonymize: req.GetAnonymize(),
|
Anonymize: req.GetAnonymize(),
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
@@ -315,7 +315,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
|||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -547,6 +547,16 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
|||||||
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
|
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
daemonVersion := "N/A"
|
||||||
|
if o.DaemonVersion != "" {
|
||||||
|
daemonVersion = o.DaemonVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
cliVersion := version.NetbirdVersion()
|
||||||
|
if o.CliVersion != "" {
|
||||||
|
cliVersion = o.CliVersion
|
||||||
|
}
|
||||||
|
|
||||||
summary := fmt.Sprintf(
|
summary := fmt.Sprintf(
|
||||||
"OS: %s\n"+
|
"OS: %s\n"+
|
||||||
"Daemon version: %s\n"+
|
"Daemon version: %s\n"+
|
||||||
@@ -567,8 +577,8 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
|||||||
"%s"+
|
"%s"+
|
||||||
"Peers count: %s\n",
|
"Peers count: %s\n",
|
||||||
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
|
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
|
||||||
o.DaemonVersion,
|
daemonVersion,
|
||||||
version.NetbirdVersion(),
|
cliVersion,
|
||||||
o.ProfileName,
|
o.ProfileName,
|
||||||
managementConnString,
|
managementConnString,
|
||||||
signalConnString,
|
signalConnString,
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
uptypes "github.com/netbirdio/netbird/upload-server/types"
|
uptypes "github.com/netbirdio/netbird/upload-server/types"
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Initial state for the debug collection
|
// Initial state for the debug collection
|
||||||
@@ -462,6 +463,7 @@ func (s *serviceClient) createDebugBundleFromCollection(
|
|||||||
request := &proto.DebugBundleRequest{
|
request := &proto.DebugBundleRequest{
|
||||||
Anonymize: params.anonymize,
|
Anonymize: params.anonymize,
|
||||||
SystemInfo: params.systemInfo,
|
SystemInfo: params.systemInfo,
|
||||||
|
CliVersion: version.NetbirdVersion(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.upload {
|
if params.upload {
|
||||||
@@ -593,6 +595,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
|
|||||||
request := &proto.DebugBundleRequest{
|
request := &proto.DebugBundleRequest{
|
||||||
Anonymize: anonymize,
|
Anonymize: anonymize,
|
||||||
SystemInfo: systemInfo,
|
SystemInfo: systemInfo,
|
||||||
|
CliVersion: version.NetbirdVersion(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if uploadURL != "" {
|
if uploadURL != "" {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync"
|
||||||
"syscall/js"
|
"syscall/js"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -13,7 +14,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
certValidationTimeout = 60 * time.Second
|
certValidationTimeout = 5 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) {
|
func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) {
|
||||||
@@ -46,17 +47,31 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert
|
|||||||
|
|
||||||
promise := conn.wsHandlers.Call("onCertificateRequest", certInfo)
|
promise := conn.wsHandlers.Call("onCertificateRequest", certInfo)
|
||||||
|
|
||||||
resultChan := make(chan bool)
|
resultChan := make(chan bool, 1)
|
||||||
errorChan := make(chan error)
|
errorChan := make(chan error, 1)
|
||||||
|
|
||||||
promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
// Release from inside the callbacks so a post-timeout promise resolution
|
||||||
result := args[0].Bool()
|
// does not invoke an already-released func.
|
||||||
resultChan <- result
|
var thenFn, catchFn js.Func
|
||||||
|
var releaseOnce sync.Once
|
||||||
|
release := func() {
|
||||||
|
releaseOnce.Do(func() {
|
||||||
|
thenFn.Release()
|
||||||
|
catchFn.Release()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
thenFn = js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||||
|
defer release()
|
||||||
|
resultChan <- args[0].Bool()
|
||||||
return nil
|
return nil
|
||||||
})).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
})
|
||||||
|
catchFn = js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||||
|
defer release()
|
||||||
errorChan <- fmt.Errorf("certificate validation failed")
|
errorChan <- fmt.Errorf("certificate validation failed")
|
||||||
return nil
|
return nil
|
||||||
}))
|
})
|
||||||
|
|
||||||
|
promise.Call("then", thenFn).Call("catch", catchFn)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case result := <-resultChan:
|
case result := <-resultChan:
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"syscall/js"
|
"syscall/js"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -57,6 +58,8 @@ type RDCleanPathProxy struct {
|
|||||||
}
|
}
|
||||||
activeConnections map[string]*proxyConnection
|
activeConnections map[string]*proxyConnection
|
||||||
destinations map[string]string
|
destinations map[string]string
|
||||||
|
pendingHandlers map[string]js.Func
|
||||||
|
nextID atomic.Uint64
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,8 +69,15 @@ type proxyConnection struct {
|
|||||||
rdpConn net.Conn
|
rdpConn net.Conn
|
||||||
tlsConn *tls.Conn
|
tlsConn *tls.Conn
|
||||||
wsHandlers js.Value
|
wsHandlers js.Value
|
||||||
ctx context.Context
|
// Go-side callbacks exposed to JS. js.FuncOf pins the Go closure in a
|
||||||
cancel context.CancelFunc
|
// global handle map and MUST be released, otherwise every connection
|
||||||
|
// leaks the Go memory the closure captures.
|
||||||
|
wsHandlerFn js.Func
|
||||||
|
onMessageFn js.Func
|
||||||
|
onCloseFn js.Func
|
||||||
|
cleanupOnce sync.Once
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRDCleanPathProxy creates a new RDCleanPath proxy
|
// NewRDCleanPathProxy creates a new RDCleanPath proxy
|
||||||
@@ -80,7 +90,11 @@ func NewRDCleanPathProxy(client interface {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateProxy creates a new proxy endpoint for the given destination
|
// CreateProxy creates a new proxy endpoint for the given destination.
|
||||||
|
// The registered handler fn and its destinations/pendingHandlers entries are
|
||||||
|
// only released once a connection is established and cleanupConnection runs.
|
||||||
|
// If a caller invokes CreateProxy but never connects to the returned URL,
|
||||||
|
// those entries stay pinned for the lifetime of the page.
|
||||||
func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
||||||
destination := net.JoinHostPort(hostname, port)
|
destination := net.JoinHostPort(hostname, port)
|
||||||
|
|
||||||
@@ -88,7 +102,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
|||||||
resolve := args[0]
|
resolve := args[0]
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections))
|
proxyID := fmt.Sprintf("proxy_%d", p.nextID.Add(1))
|
||||||
|
|
||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
if p.destinations == nil {
|
if p.destinations == nil {
|
||||||
@@ -100,7 +114,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
|||||||
proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID)
|
proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID)
|
||||||
|
|
||||||
// Register the WebSocket handler for this specific proxy
|
// Register the WebSocket handler for this specific proxy
|
||||||
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any {
|
handlerFn := js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||||
if len(args) < 1 {
|
if len(args) < 1 {
|
||||||
return js.ValueOf("error: requires WebSocket argument")
|
return js.ValueOf("error: requires WebSocket argument")
|
||||||
}
|
}
|
||||||
@@ -108,7 +122,14 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
|||||||
ws := args[0]
|
ws := args[0]
|
||||||
p.HandleWebSocketConnection(ws, proxyID)
|
p.HandleWebSocketConnection(ws, proxyID)
|
||||||
return nil
|
return nil
|
||||||
}))
|
})
|
||||||
|
p.mu.Lock()
|
||||||
|
if p.pendingHandlers == nil {
|
||||||
|
p.pendingHandlers = make(map[string]js.Func)
|
||||||
|
}
|
||||||
|
p.pendingHandlers[proxyID] = handlerFn
|
||||||
|
p.mu.Unlock()
|
||||||
|
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), handlerFn)
|
||||||
|
|
||||||
log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination)
|
log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination)
|
||||||
resolve.Invoke(proxyURL)
|
resolve.Invoke(proxyURL)
|
||||||
@@ -142,6 +163,10 @@ func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string
|
|||||||
|
|
||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
p.activeConnections[proxyID] = conn
|
p.activeConnections[proxyID] = conn
|
||||||
|
if fn, ok := p.pendingHandlers[proxyID]; ok {
|
||||||
|
conn.wsHandlerFn = fn
|
||||||
|
delete(p.pendingHandlers, proxyID)
|
||||||
|
}
|
||||||
p.mu.Unlock()
|
p.mu.Unlock()
|
||||||
|
|
||||||
p.setupWebSocketHandlers(ws, conn)
|
p.setupWebSocketHandlers(ws, conn)
|
||||||
@@ -150,7 +175,7 @@ func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) {
|
func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) {
|
||||||
ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any {
|
conn.onMessageFn = js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
if len(args) < 1 {
|
if len(args) < 1 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -158,13 +183,15 @@ func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnec
|
|||||||
data := args[0]
|
data := args[0]
|
||||||
go p.handleWebSocketMessage(conn, data)
|
go p.handleWebSocketMessage(conn, data)
|
||||||
return nil
|
return nil
|
||||||
}))
|
})
|
||||||
|
ws.Set("onGoMessage", conn.onMessageFn)
|
||||||
|
|
||||||
ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any {
|
conn.onCloseFn = js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||||
log.Debug("WebSocket closed by JavaScript")
|
log.Debug("WebSocket closed by JavaScript")
|
||||||
conn.cancel()
|
conn.cancel()
|
||||||
return nil
|
return nil
|
||||||
}))
|
})
|
||||||
|
ws.Set("onGoClose", conn.onCloseFn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) {
|
func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) {
|
||||||
@@ -261,25 +288,49 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) {
|
func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) {
|
||||||
log.Debugf("Cleaning up connection %s", conn.id)
|
conn.cleanupOnce.Do(func() {
|
||||||
conn.cancel()
|
log.Debugf("Cleaning up connection %s", conn.id)
|
||||||
if conn.tlsConn != nil {
|
conn.cancel()
|
||||||
log.Debug("Closing TLS connection")
|
if conn.tlsConn != nil {
|
||||||
if err := conn.tlsConn.Close(); err != nil {
|
log.Debug("Closing TLS connection")
|
||||||
log.Debugf("Error closing TLS connection: %v", err)
|
if err := conn.tlsConn.Close(); err != nil {
|
||||||
|
log.Debugf("Error closing TLS connection: %v", err)
|
||||||
|
}
|
||||||
|
conn.tlsConn = nil
|
||||||
}
|
}
|
||||||
conn.tlsConn = nil
|
if conn.rdpConn != nil {
|
||||||
}
|
log.Debug("Closing TCP connection")
|
||||||
if conn.rdpConn != nil {
|
if err := conn.rdpConn.Close(); err != nil {
|
||||||
log.Debug("Closing TCP connection")
|
log.Debugf("Error closing TCP connection: %v", err)
|
||||||
if err := conn.rdpConn.Close(); err != nil {
|
}
|
||||||
log.Debugf("Error closing TCP connection: %v", err)
|
conn.rdpConn = nil
|
||||||
}
|
}
|
||||||
conn.rdpConn = nil
|
js.Global().Delete(fmt.Sprintf("handleRDCleanPathWebSocket_%s", conn.id))
|
||||||
}
|
|
||||||
p.mu.Lock()
|
// Detach before releasing so late JS calls surface as TypeError instead
|
||||||
delete(p.activeConnections, conn.id)
|
// of silent "call to released function".
|
||||||
p.mu.Unlock()
|
if conn.wsHandlers.Truthy() {
|
||||||
|
conn.wsHandlers.Set("onGoMessage", js.Undefined())
|
||||||
|
conn.wsHandlers.Set("onGoClose", js.Undefined())
|
||||||
|
}
|
||||||
|
|
||||||
|
// wsHandlerFn may be zero-value if the pending handler lookup missed.
|
||||||
|
if conn.wsHandlerFn.Truthy() {
|
||||||
|
conn.wsHandlerFn.Release()
|
||||||
|
}
|
||||||
|
if conn.onMessageFn.Truthy() {
|
||||||
|
conn.onMessageFn.Release()
|
||||||
|
}
|
||||||
|
if conn.onCloseFn.Truthy() {
|
||||||
|
conn.onCloseFn.Release()
|
||||||
|
}
|
||||||
|
|
||||||
|
p.mu.Lock()
|
||||||
|
delete(p.activeConnections, conn.id)
|
||||||
|
delete(p.destinations, conn.id)
|
||||||
|
delete(p.pendingHandlers, conn.id)
|
||||||
|
p.mu.Unlock()
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
|
func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
func CreateJSInterface(client *Client) js.Value {
|
func CreateJSInterface(client *Client) js.Value {
|
||||||
jsInterface := js.Global().Get("Object").Call("create", js.Null())
|
jsInterface := js.Global().Get("Object").Call("create", js.Null())
|
||||||
|
|
||||||
jsInterface.Set("write", js.FuncOf(func(this js.Value, args []js.Value) any {
|
writeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
if len(args) < 1 {
|
if len(args) < 1 {
|
||||||
return js.ValueOf(false)
|
return js.ValueOf(false)
|
||||||
}
|
}
|
||||||
@@ -32,9 +32,10 @@ func CreateJSInterface(client *Client) js.Value {
|
|||||||
|
|
||||||
_, err := client.Write(bytes)
|
_, err := client.Write(bytes)
|
||||||
return js.ValueOf(err == nil)
|
return js.ValueOf(err == nil)
|
||||||
}))
|
})
|
||||||
|
jsInterface.Set("write", writeFunc)
|
||||||
|
|
||||||
jsInterface.Set("resize", js.FuncOf(func(this js.Value, args []js.Value) any {
|
resizeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
if len(args) < 2 {
|
if len(args) < 2 {
|
||||||
return js.ValueOf(false)
|
return js.ValueOf(false)
|
||||||
}
|
}
|
||||||
@@ -42,14 +43,26 @@ func CreateJSInterface(client *Client) js.Value {
|
|||||||
rows := args[1].Int()
|
rows := args[1].Int()
|
||||||
err := client.Resize(cols, rows)
|
err := client.Resize(cols, rows)
|
||||||
return js.ValueOf(err == nil)
|
return js.ValueOf(err == nil)
|
||||||
}))
|
})
|
||||||
|
jsInterface.Set("resize", resizeFunc)
|
||||||
|
|
||||||
jsInterface.Set("close", js.FuncOf(func(this js.Value, args []js.Value) any {
|
closeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
client.Close()
|
client.Close()
|
||||||
return js.Undefined()
|
return js.Undefined()
|
||||||
}))
|
})
|
||||||
|
jsInterface.Set("close", closeFunc)
|
||||||
|
|
||||||
go readLoop(client, jsInterface)
|
go func() {
|
||||||
|
readLoop(client, jsInterface)
|
||||||
|
// Detach before releasing so late JS calls surface as TypeError instead
|
||||||
|
// of silent "call to released function".
|
||||||
|
jsInterface.Set("write", js.Undefined())
|
||||||
|
jsInterface.Set("resize", js.Undefined())
|
||||||
|
jsInterface.Set("close", js.Undefined())
|
||||||
|
writeFunc.Release()
|
||||||
|
resizeFunc.Release()
|
||||||
|
closeFunc.Release()
|
||||||
|
}()
|
||||||
|
|
||||||
return jsInterface
|
return jsInterface
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -332,7 +332,7 @@ func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) {
|
|||||||
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
|
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
|
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), s.IDPHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
|
||||||
if servers.relaySrv != nil {
|
if servers.relaySrv != nil {
|
||||||
log.Infof("Relay WebSocket handler added (path: /relay)")
|
log.Infof("Relay WebSocket handler added (path: /relay)")
|
||||||
}
|
}
|
||||||
@@ -521,7 +521,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
|
|||||||
}
|
}
|
||||||
|
|
||||||
// createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic
|
// createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic
|
||||||
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
|
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
|
||||||
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
|
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
|
||||||
|
|
||||||
var relayAcceptFn func(conn listener.Conn)
|
var relayAcceptFn func(conn listener.Conn)
|
||||||
@@ -556,6 +556,10 @@ func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, re
|
|||||||
http.Error(w, "Relay service not enabled", http.StatusNotFound)
|
http.Error(w, "Relay service not enabled", http.StatusNotFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Embedded IdP (Dex)
|
||||||
|
case idpHandler != nil && strings.HasPrefix(r.URL.Path, "/oauth2"):
|
||||||
|
idpHandler.ServeHTTP(w, r)
|
||||||
|
|
||||||
// Management HTTP API (default)
|
// Management HTTP API (default)
|
||||||
default:
|
default:
|
||||||
httpHandler.ServeHTTP(w, r)
|
httpHandler.ServeHTTP(w, r)
|
||||||
|
|||||||
14
go.mod
14
go.mod
@@ -3,7 +3,7 @@ module github.com/netbirdio/netbird
|
|||||||
go 1.25.5
|
go 1.25.5
|
||||||
|
|
||||||
require (
|
require (
|
||||||
cunicu.li/go-rosenpass v0.4.0
|
cunicu.li/go-rosenpass v0.5.42
|
||||||
github.com/cenkalti/backoff/v4 v4.3.0
|
github.com/cenkalti/backoff/v4 v4.3.0
|
||||||
github.com/cloudflare/circl v1.3.3 // indirect
|
github.com/cloudflare/circl v1.3.3 // indirect
|
||||||
github.com/golang/protobuf v1.5.4
|
github.com/golang/protobuf v1.5.4
|
||||||
@@ -19,18 +19,18 @@ require (
|
|||||||
github.com/vishvananda/netlink v1.3.1
|
github.com/vishvananda/netlink v1.3.1
|
||||||
golang.org/x/crypto v0.50.0
|
golang.org/x/crypto v0.50.0
|
||||||
golang.org/x/sys v0.43.0
|
golang.org/x/sys v0.43.0
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||||
google.golang.org/grpc v1.80.0
|
google.golang.org/grpc v1.80.0
|
||||||
google.golang.org/protobuf v1.36.11
|
google.golang.org/protobuf v1.36.11
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
fyne.io/fyne/v2 v2.7.0
|
fyne.io/fyne/v2 v2.7.0
|
||||||
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9
|
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9
|
||||||
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3
|
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3
|
||||||
|
github.com/DeRuina/timberjack v1.4.2
|
||||||
github.com/awnumar/memguard v0.23.0
|
github.com/awnumar/memguard v0.23.0
|
||||||
github.com/aws/aws-sdk-go-v2 v1.38.3
|
github.com/aws/aws-sdk-go-v2 v1.38.3
|
||||||
github.com/aws/aws-sdk-go-v2/config v1.31.6
|
github.com/aws/aws-sdk-go-v2/config v1.31.6
|
||||||
@@ -38,7 +38,7 @@ require (
|
|||||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3
|
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3
|
||||||
github.com/c-robinson/iplib v1.0.3
|
github.com/c-robinson/iplib v1.0.3
|
||||||
github.com/caddyserver/certmagic v0.21.3
|
github.com/caddyserver/certmagic v0.21.3
|
||||||
github.com/cilium/ebpf v0.15.0
|
github.com/cilium/ebpf v0.19.0
|
||||||
github.com/coder/websocket v1.8.14
|
github.com/coder/websocket v1.8.14
|
||||||
github.com/coreos/go-iptables v0.7.0
|
github.com/coreos/go-iptables v0.7.0
|
||||||
github.com/coreos/go-oidc/v3 v3.18.0
|
github.com/coreos/go-oidc/v3 v3.18.0
|
||||||
@@ -60,7 +60,7 @@ require (
|
|||||||
github.com/google/go-cmp v0.7.0
|
github.com/google/go-cmp v0.7.0
|
||||||
github.com/google/gopacket v1.1.19
|
github.com/google/gopacket v1.1.19
|
||||||
github.com/google/nftables v0.3.0
|
github.com/google/nftables v0.3.0
|
||||||
github.com/gopacket/gopacket v1.1.1
|
github.com/gopacket/gopacket v1.4.0
|
||||||
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
|
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
|
||||||
github.com/hashicorp/go-multierror v1.1.1
|
github.com/hashicorp/go-multierror v1.1.1
|
||||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
||||||
@@ -335,7 +335,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
|||||||
|
|
||||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||||
|
|
||||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0
|
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f
|
||||||
|
|
||||||
replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||||
|
|
||||||
|
|||||||
30
go.sum
30
go.sum
@@ -7,8 +7,8 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdB
|
|||||||
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
|
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
|
||||||
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:b8xUw3004wk+3ipBhu0VU4RtUJsegMIiqjxSK4++lzA=
|
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:b8xUw3004wk+3ipBhu0VU4RtUJsegMIiqjxSK4++lzA=
|
||||||
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
|
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
|
||||||
cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw=
|
cunicu.li/go-rosenpass v0.5.42 h1:fRDsGwCxd7DhDgZI1Pxeo8GtNyq8BESZJ7w2/BGGJtU=
|
||||||
cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M=
|
cunicu.li/go-rosenpass v0.5.42/go.mod h1:YRBeyKOe/gWpSX2kpDUec5p9t0XOLsshTguId5gTGVg=
|
||||||
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
|
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
|
||||||
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
|
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
|
||||||
filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw=
|
filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw=
|
||||||
@@ -29,6 +29,8 @@ github.com/Azure/go-ntlmssp v0.1.0 h1:DjFo6YtWzNqNvQdrwEyr/e4nhU3vRiwenz5QX7sFz+
|
|||||||
github.com/Azure/go-ntlmssp v0.1.0/go.mod h1:NYqdhxd/8aAct/s4qSYZEerdPuH1liG2/X9DiVTbhpk=
|
github.com/Azure/go-ntlmssp v0.1.0/go.mod h1:NYqdhxd/8aAct/s4qSYZEerdPuH1liG2/X9DiVTbhpk=
|
||||||
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
||||||
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||||
|
github.com/DeRuina/timberjack v1.4.2 h1:4bKlzhKdsR+2oNkgef9mqb4n11ICow8VK88RfzJPzN8=
|
||||||
|
github.com/DeRuina/timberjack v1.4.2/go.mod h1:RLoeQrwrCGIEF8gO5nV5b/gMD0QIy7bzQhBUgpp1EqE=
|
||||||
github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
|
github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
|
||||||
github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU=
|
github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU=
|
||||||
github.com/Masterminds/semver/v3 v3.3.0 h1:B8LGeaivUe71a5qox1ICM/JLl0NqZSW5CHyL+hmvYS0=
|
github.com/Masterminds/semver/v3 v3.3.0 h1:B8LGeaivUe71a5qox1ICM/JLl0NqZSW5CHyL+hmvYS0=
|
||||||
@@ -111,8 +113,8 @@ github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1x
|
|||||||
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
|
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
|
github.com/cilium/ebpf v0.19.0 h1:Ro/rE64RmFBeA9FGjcTc+KmCeY6jXmryu6FfnzPRIao=
|
||||||
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
|
github.com/cilium/ebpf v0.19.0/go.mod h1:fLCgMo3l8tZmAdM3B2XqdFzXBpwkcSTroaVqN08OWVY=
|
||||||
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
||||||
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||||
@@ -225,8 +227,8 @@ github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3Bum
|
|||||||
github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ=
|
github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ=
|
||||||
github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM=
|
github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM=
|
||||||
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
|
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
|
||||||
github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI=
|
github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6 h1:teYtXy9B7y5lHTp8V9KPxpYRAVA7dozigQcMiBust1s=
|
||||||
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
|
github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6/go.mod h1:p4lGIVX+8Wa6ZPNDvqcxq36XpUDLh42FLetFU7odllI=
|
||||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||||
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
|
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
|
||||||
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
|
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
|
||||||
@@ -307,8 +309,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA
|
|||||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
||||||
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
|
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
|
||||||
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
|
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
|
||||||
github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw=
|
github.com/gopacket/gopacket v1.4.0 h1:cr1OlFpzksCkZHNO0eLjaSSOrMQnpPXg0j6qHIY3y2U=
|
||||||
github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs=
|
github.com/gopacket/gopacket v1.4.0/go.mod h1:EpvsxINeehp5qj4YMKMLf2/dekdhKn2IIAO/ZOifS7o=
|
||||||
github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
|
github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
|
||||||
github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w=
|
github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w=
|
||||||
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
||||||
@@ -390,6 +392,8 @@ github.com/jonboulle/clockwork v0.5.0 h1:Hyh9A8u51kptdkR+cqRpT1EebBwTn1oK9YfGYbd
|
|||||||
github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60=
|
github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60=
|
||||||
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
||||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||||
|
github.com/jsimonetti/rtnetlink/v2 v2.0.1 h1:xda7qaHDSVOsADNouv7ukSuicKZO7GgVUCXxpaIEIlM=
|
||||||
|
github.com/jsimonetti/rtnetlink/v2 v2.0.1/go.mod h1:7MoNYNbb3UaDHtF8udiJo/RH6VsTKP1pqKLUTVCvToE=
|
||||||
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||||
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M=
|
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M=
|
||||||
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
|
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
|
||||||
@@ -499,8 +503,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
|||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ=
|
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f h1:ff2D57RBjWtyQ2wVwJOxOgXAXOe/J2lJWtSX0Bz/BRk=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||||
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
|
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
|
||||||
@@ -900,8 +904,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
|
|||||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
||||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||||
@@ -938,8 +942,6 @@ gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8
|
|||||||
gopkg.in/go-playground/validator.v9 v9.29.1/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ=
|
gopkg.in/go-playground/validator.v9 v9.29.1/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ=
|
||||||
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
||||||
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
|
|
||||||
gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI=
|
gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI=
|
||||||
gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
||||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||||
|
|||||||
@@ -112,7 +112,7 @@ func (c *Controller) CountStreams() int {
|
|||||||
return c.peersUpdateManager.CountStreams()
|
return c.peersUpdateManager.CountStreams()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
||||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -175,6 +175,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.accountManagerMetrics != nil {
|
||||||
|
c.accountManagerMetrics.CountNmapTriggered(string(reason.Resource), string(reason.Operation))
|
||||||
|
}
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
semaphore <- struct{}{}
|
semaphore <- struct{}{}
|
||||||
go func(p *nbpeer.Peer) {
|
go func(p *nbpeer.Peer) {
|
||||||
@@ -242,14 +246,14 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer b.mu.Unlock()
|
defer b.mu.Unlock()
|
||||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||||
if !b.update.Load() {
|
if !b.update.Load() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
b.update.Store(false)
|
b.update.Store(false)
|
||||||
if b.next == nil {
|
if b.next == nil {
|
||||||
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -265,7 +269,7 @@ func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, r
|
|||||||
if c.accountManagerMetrics != nil {
|
if c.accountManagerMetrics != nil {
|
||||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||||
}
|
}
|
||||||
return c.sendUpdateAccountPeers(ctx, accountID)
|
return c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
|
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
|
||||||
@@ -359,14 +363,14 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer b.mu.Unlock()
|
defer b.mu.Unlock()
|
||||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||||
if !b.update.Load() {
|
if !b.update.Load() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
b.update.Store(false)
|
b.update.Store(false)
|
||||||
if b.next == nil {
|
if b.next == nil {
|
||||||
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
|||||||
found = true
|
found = true
|
||||||
select {
|
select {
|
||||||
case channel <- update:
|
case channel <- update:
|
||||||
log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID)
|
log.WithContext(ctx).Tracef("update was sent to channel for peer %s", peerID)
|
||||||
default:
|
default:
|
||||||
dropped = true
|
dropped = true
|
||||||
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))
|
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))
|
||||||
|
|||||||
@@ -10,8 +10,10 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
|
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
|
||||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||||
|
"github.com/rs/cors"
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
@@ -19,7 +21,6 @@ import (
|
|||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
cachestore "github.com/eko/gocache/lib/v4/store"
|
cachestore "github.com/eko/gocache/lib/v4/store"
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
"github.com/netbirdio/netbird/formatter/hook"
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
@@ -27,16 +28,20 @@ import (
|
|||||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
activitystore "github.com/netbirdio/netbird/management/server/activity/store"
|
||||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/util/crypt"
|
"github.com/netbirdio/netbird/util/crypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const apiPrefix = "/api"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
kaep = keepalive.EnforcementPolicy{
|
kaep = keepalive.EnforcementPolicy{
|
||||||
MinTime: 15 * time.Second,
|
MinTime: 15 * time.Second,
|
||||||
@@ -94,12 +99,17 @@ func (s *BaseServer) Store() store.Store {
|
|||||||
|
|
||||||
func (s *BaseServer) EventStore() activity.Store {
|
func (s *BaseServer) EventStore() activity.Store {
|
||||||
return Create(s, func() activity.Store {
|
return Create(s, func() activity.Store {
|
||||||
integrationMetrics, err := integrations.InitIntegrationMetrics(context.Background(), s.Metrics())
|
var err error
|
||||||
if err != nil {
|
key := s.Config.DataStoreEncryptionKey
|
||||||
log.Fatalf("failed to initialize integration metrics: %v", err)
|
if key == "" {
|
||||||
|
log.Debugf("generate new activity store encryption key")
|
||||||
|
key, err = crypt.GenerateKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to generate event store encryption key: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
eventStore, _, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics)
|
eventStore, err := activitystore.NewSqlStore(context.Background(), s.Config.Datadir, key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to initialize event store: %v", err)
|
log.Fatalf("failed to initialize event store: %v", err)
|
||||||
}
|
}
|
||||||
@@ -110,7 +120,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
|||||||
|
|
||||||
func (s *BaseServer) APIHandler() http.Handler {
|
func (s *BaseServer) APIHandler() http.Handler {
|
||||||
return Create(s, func() http.Handler {
|
return Create(s, func() http.Handler {
|
||||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter())
|
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create API handler: %v", err)
|
log.Fatalf("failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
@@ -118,6 +128,22 @@ func (s *BaseServer) APIHandler() http.Handler {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IDPHandler returns the HTTP handler for the embedded IdP (Dex), or nil if
|
||||||
|
// the deployment isn't using the embedded variant.
|
||||||
|
func (s *BaseServer) IDPHandler() http.Handler {
|
||||||
|
embeddedIdP, ok := s.IdpManager().(*idp.EmbeddedIdPManager)
|
||||||
|
if !ok || embeddedIdP == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return cors.AllowAll().Handler(embeddedIdP.Handler())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) Router() *mux.Router {
|
||||||
|
return Create(s, func() *mux.Router {
|
||||||
|
return mux.NewRouter().PathPrefix(apiPrefix).Subrouter()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
|
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
|
||||||
return Create(s, func() *middleware.APIRateLimiter {
|
return Create(s, func() *middleware.APIRateLimiter {
|
||||||
cfg, enabled := middleware.RateLimiterConfigFromEnv()
|
cfg, enabled := middleware.RateLimiterConfigFromEnv()
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/auth"
|
"github.com/netbirdio/netbird/management/server/auth"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/job"
|
"github.com/netbirdio/netbird/management/server/job"
|
||||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||||
@@ -38,7 +39,7 @@ func (s *BaseServer) JobManager() *job.Manager {
|
|||||||
|
|
||||||
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
|
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
|
||||||
return Create(s, func() integrated_validator.IntegratedValidator {
|
return Create(s, func() integrated_validator.IntegratedValidator {
|
||||||
integratedPeerValidator, err := integrations.NewIntegratedValidator(
|
integratedPeerValidator, err := validator.NewIntegratedValidator(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
s.PeersManager(),
|
s.PeersManager(),
|
||||||
s.SettingsManager(),
|
s.SettingsManager(),
|
||||||
|
|||||||
@@ -57,13 +57,7 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
|
|||||||
|
|
||||||
func (s *BaseServer) PermissionsManager() permissions.Manager {
|
func (s *BaseServer) PermissionsManager() permissions.Manager {
|
||||||
return Create(s, func() permissions.Manager {
|
return Create(s, func() permissions.Manager {
|
||||||
manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter())
|
return permissions.NewManager(s.Store())
|
||||||
|
|
||||||
s.AfterInit(func(s *BaseServer) {
|
|
||||||
manager.SetAccountManager(s.AccountManager())
|
|
||||||
})
|
|
||||||
|
|
||||||
return manager
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -153,7 +147,6 @@ func (s *BaseServer) IdpManager() idp.Manager {
|
|||||||
return idpManager
|
return idpManager
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -235,3 +228,7 @@ func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
|||||||
return &m
|
return &m
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) IsValidChildAccount(_ context.Context, _, _, _ string) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -188,7 +188,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
|||||||
log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
|
log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter())
|
rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.IDPHandler(), s.Metrics().GetMeter())
|
||||||
switch {
|
switch {
|
||||||
case s.certManager != nil:
|
case s.certManager != nil:
|
||||||
// a call to certManager.Listener() always creates a new listener so we do it once
|
// a call to certManager.Listener() always creates a new listener so we do it once
|
||||||
@@ -299,7 +299,7 @@ func (s *BaseServer) SetHandlerFunc(handler http.Handler) {
|
|||||||
log.Tracef("custom handler set successfully")
|
log.Tracef("custom handler set successfully")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
|
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, meter metric.Meter) http.Handler {
|
||||||
// Check if a custom handler was set (for multiplexing additional services)
|
// Check if a custom handler was set (for multiplexing additional services)
|
||||||
if customHandler, ok := s.GetContainer("customHandler"); ok {
|
if customHandler, ok := s.GetContainer("customHandler"); ok {
|
||||||
if handler, ok := customHandler.(http.Handler); ok {
|
if handler, ok := customHandler.(http.Handler); ok {
|
||||||
@@ -318,6 +318,8 @@ func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, ht
|
|||||||
gRPCHandler.ServeHTTP(writer, request)
|
gRPCHandler.ServeHTTP(writer, request)
|
||||||
case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
|
case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
|
||||||
wsProxy.Handler().ServeHTTP(writer, request)
|
wsProxy.Handler().ServeHTTP(writer, request)
|
||||||
|
case idpHandler != nil && strings.HasPrefix(request.URL.Path, "/oauth2"):
|
||||||
|
idpHandler.ServeHTTP(writer, request)
|
||||||
default:
|
default:
|
||||||
httpHandler.ServeHTTP(writer, request)
|
httpHandler.ServeHTTP(writer, request)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -437,7 +437,7 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
log.WithContext(ctx).Tracef("received an update for peer %s", peerKey.String())
|
||||||
if debouncer.ProcessUpdate(update) {
|
if debouncer.ProcessUpdate(update) {
|
||||||
// Send immediately (first update or after quiet period)
|
// Send immediately (first update or after quiet period)
|
||||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
|
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
|
||||||
@@ -492,7 +492,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
|
|||||||
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||||
return status.Errorf(codes.Internal, "failed sending update message")
|
return status.Errorf(codes.Internal, "failed sending update message")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
log.WithContext(ctx).Tracef("sent an update to peer %s", peerKey.String())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,15 +15,13 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||||
|
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||||
@@ -32,12 +30,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
|
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
|
||||||
|
|
||||||
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
|
|
||||||
"github.com/netbirdio/netbird/management/server/auth"
|
"github.com/netbirdio/netbird/management/server/auth"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
nbgroups "github.com/netbirdio/netbird/management/server/groups"
|
nbgroups "github.com/netbirdio/netbird/management/server/groups"
|
||||||
@@ -56,17 +52,14 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||||
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
|
||||||
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
|
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
|
||||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
)
|
)
|
||||||
|
|
||||||
const apiPrefix = "/api"
|
|
||||||
|
|
||||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
|
func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter, isValidChildAccount middleware.IsValidChildAccountFunc) (http.Handler, error) {
|
||||||
|
|
||||||
// Register bypass paths for unauthenticated endpoints
|
// Register bypass paths for unauthenticated endpoints
|
||||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||||
@@ -100,25 +93,16 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
|||||||
accountManager.GetUserFromUserAuth,
|
accountManager.GetUserFromUserAuth,
|
||||||
rateLimiter,
|
rateLimiter,
|
||||||
appMetrics.GetMeter(),
|
appMetrics.GetMeter(),
|
||||||
|
isValidChildAccount,
|
||||||
)
|
)
|
||||||
|
|
||||||
corsMiddleware := cors.AllowAll()
|
corsMiddleware := cors.AllowAll()
|
||||||
|
|
||||||
rootRouter := mux.NewRouter()
|
|
||||||
metricsMiddleware := appMetrics.HTTPMiddleware()
|
metricsMiddleware := appMetrics.HTTPMiddleware()
|
||||||
|
|
||||||
prefix := apiPrefix
|
|
||||||
router := rootRouter.PathPrefix(prefix).Subrouter()
|
|
||||||
|
|
||||||
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler)
|
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler)
|
||||||
|
|
||||||
if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, integratedValidator, appMetrics.GetMeter(), permissionsManager, peersManager, proxyController, settingsManager); err != nil {
|
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), idpManager)
|
||||||
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if embedded IdP is enabled for instance manager
|
|
||||||
embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager)
|
|
||||||
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create instance manager: %w", err)
|
return nil, fmt.Errorf("failed to create instance manager: %w", err)
|
||||||
}
|
}
|
||||||
@@ -154,10 +138,5 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
|||||||
oauthHandler.RegisterEndpoints(router)
|
oauthHandler.RegisterEndpoints(router)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mount embedded IdP handler at /oauth2 path if configured
|
return router, nil
|
||||||
if embeddedIdpEnabled {
|
|
||||||
rootRouter.PathPrefix("/oauth2").Handler(corsMiddleware.Handler(embeddedIdP.Handler()))
|
|
||||||
}
|
|
||||||
|
|
||||||
return rootRouter, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,8 +11,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"go.opentelemetry.io/otel/metric"
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
|
||||||
|
|
||||||
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||||
@@ -27,6 +25,8 @@ type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) err
|
|||||||
|
|
||||||
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||||
|
|
||||||
|
type IsValidChildAccountFunc func(ctx context.Context, userID, accountID, childAccountID string) bool
|
||||||
|
|
||||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||||
type AuthMiddleware struct {
|
type AuthMiddleware struct {
|
||||||
authManager serverauth.Manager
|
authManager serverauth.Manager
|
||||||
@@ -35,6 +35,7 @@ type AuthMiddleware struct {
|
|||||||
syncUserJWTGroups SyncUserJWTGroupsFunc
|
syncUserJWTGroups SyncUserJWTGroupsFunc
|
||||||
rateLimiter *APIRateLimiter
|
rateLimiter *APIRateLimiter
|
||||||
patUsageTracker *PATUsageTracker
|
patUsageTracker *PATUsageTracker
|
||||||
|
isValidChildAccount IsValidChildAccountFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAuthMiddleware instance constructor
|
// NewAuthMiddleware instance constructor
|
||||||
@@ -45,6 +46,7 @@ func NewAuthMiddleware(
|
|||||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||||
rateLimiter *APIRateLimiter,
|
rateLimiter *APIRateLimiter,
|
||||||
meter metric.Meter,
|
meter metric.Meter,
|
||||||
|
isValidChildAccount IsValidChildAccountFunc,
|
||||||
) *AuthMiddleware {
|
) *AuthMiddleware {
|
||||||
var patUsageTracker *PATUsageTracker
|
var patUsageTracker *PATUsageTracker
|
||||||
if meter != nil {
|
if meter != nil {
|
||||||
@@ -62,6 +64,7 @@ func NewAuthMiddleware(
|
|||||||
getUserFromUserAuth: getUserFromUserAuth,
|
getUserFromUserAuth: getUserFromUserAuth,
|
||||||
rateLimiter: rateLimiter,
|
rateLimiter: rateLimiter,
|
||||||
patUsageTracker: patUsageTracker,
|
patUsageTracker: patUsageTracker,
|
||||||
|
isValidChildAccount: isValidChildAccount,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,7 +127,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
|||||||
}
|
}
|
||||||
|
|
||||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||||
if integrations.IsValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
if m.isValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||||
userAuth.AccountId = impersonate[0]
|
userAuth.AccountId = impersonate[0]
|
||||||
userAuth.IsChild = true
|
userAuth.IsChild = true
|
||||||
}
|
}
|
||||||
@@ -203,7 +206,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
|||||||
}
|
}
|
||||||
|
|
||||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||||
if integrations.IsValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
if m.isValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||||
userAuth.AccountId = impersonate[0]
|
userAuth.AccountId = impersonate[0]
|
||||||
userAuth.IsChild = true
|
userAuth.IsChild = true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -211,6 +211,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
|||||||
},
|
},
|
||||||
disabledLimiter,
|
disabledLimiter,
|
||||||
nil,
|
nil,
|
||||||
|
func(_ context.Context, _, _, _ string) bool { return false },
|
||||||
)
|
)
|
||||||
|
|
||||||
handlerToTest := authMiddleware.Handler(nextHandler)
|
handlerToTest := authMiddleware.Handler(nextHandler)
|
||||||
@@ -270,6 +271,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
|
func(_ context.Context, _, _, _ string) bool { return false },
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -322,6 +324,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
|
func(_ context.Context, _, _, _ string) bool { return false },
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -365,6 +368,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
|
func(_ context.Context, _, _, _ string) bool { return false },
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -409,6 +413,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
|
func(_ context.Context, _, _, _ string) bool { return false },
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -473,6 +478,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
|
func(_ context.Context, _, _, _ string) bool { return false },
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -532,6 +538,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
|
func(_ context.Context, _, _, _ string) bool { return false },
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -587,6 +594,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
|
func(_ context.Context, _, _, _ string) bool { return false },
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -687,6 +695,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
|||||||
},
|
},
|
||||||
disabledLimiter,
|
disabledLimiter,
|
||||||
nil,
|
nil,
|
||||||
|
func(_ context.Context, _, _, _ string) bool { return false },
|
||||||
)
|
)
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"go.opentelemetry.io/otel/metric/noop"
|
"go.opentelemetry.io/otel/metric/noop"
|
||||||
@@ -135,7 +136,8 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
|||||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||||
|
|
||||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
|
||||||
|
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create API handler: %v", err)
|
t.Fatalf("Failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
@@ -264,7 +266,8 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
|||||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||||
|
|
||||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
|
||||||
|
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create API handler: %v", err)
|
t.Fatalf("Failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,62 @@
|
|||||||
|
package validator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
cachestore "github.com/eko/gocache/lib/v4/store"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type IntegratedValidatorImpl struct{}
|
||||||
|
|
||||||
|
func NewIntegratedValidator(_ context.Context, _ peers.Manager, _ settings.Manager, _ activity.Store, _ cachestore.StoreInterface) (*IntegratedValidatorImpl, error) {
|
||||||
|
return &IntegratedValidatorImpl{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *IntegratedValidatorImpl) ValidateExtraSettings(context.Context, *types.ExtraSettings, *types.ExtraSettings, string, string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *IntegratedValidatorImpl) ValidatePeer(_ context.Context, update *nbpeer.Peer, _ *nbpeer.Peer, _ string, _ string, _ string, _ []string, _ *types.ExtraSettings) (*nbpeer.Peer, bool, error) {
|
||||||
|
return update, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *IntegratedValidatorImpl) PreparePeer(_ context.Context, _ string, peer *nbpeer.Peer, _ []string, _ *types.ExtraSettings, _ bool) *nbpeer.Peer {
|
||||||
|
return peer.Copy()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *IntegratedValidatorImpl) IsNotValidPeer(_ context.Context, _ string, _ *nbpeer.Peer, _ []string, _ *types.ExtraSettings) (bool, bool, error) {
|
||||||
|
return false, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *IntegratedValidatorImpl) GetValidatedPeers(_ context.Context, _ string, _ []*types.Group, peers []*nbpeer.Peer, _ *types.ExtraSettings) (map[string]struct{}, error) {
|
||||||
|
validatedPeers := make(map[string]struct{})
|
||||||
|
for _, p := range peers {
|
||||||
|
validatedPeers[p.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
return validatedPeers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *IntegratedValidatorImpl) GetInvalidPeers(_ context.Context, _ string, _ *types.ExtraSettings) (map[string]string, error) {
|
||||||
|
return make(map[string]string), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *IntegratedValidatorImpl) PeerDeleted(_ context.Context, _, _ string, _ *types.ExtraSettings) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *IntegratedValidatorImpl) SetPeerInvalidationListener(_ func(accountID string, peerIDs []string)) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *IntegratedValidatorImpl) Stop(_ context.Context) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *IntegratedValidatorImpl) ValidateFlowResponse(_ context.Context, _ string, flowResponse *proto.PKCEAuthorizationFlow) *proto.PKCEAuthorizationFlow {
|
||||||
|
return flowResponse
|
||||||
|
}
|
||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/hashicorp/go-version"
|
"github.com/hashicorp/go-version"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
)
|
)
|
||||||
@@ -33,9 +32,6 @@ func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, err
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("peer %s NB version %s is older than minimum allowed version %s",
|
|
||||||
peer.ID, peer.Meta.WtVersion, n.MinVersion)
|
|
||||||
|
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -100,8 +100,6 @@ func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *M
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("peer %s OS version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinVersion)
|
|
||||||
|
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -125,7 +123,5 @@ func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, ch
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("peer %s kernel version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinKernelVersion)
|
|
||||||
|
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ type AccountManagerMetrics struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
updateAccountPeersDurationMs metric.Float64Histogram
|
updateAccountPeersDurationMs metric.Float64Histogram
|
||||||
updateAccountPeersCounter metric.Int64Counter
|
updateAccountPeersCounter metric.Int64Counter
|
||||||
|
nmapCounter metric.Int64Counter
|
||||||
getPeerNetworkMapDurationMs metric.Float64Histogram
|
getPeerNetworkMapDurationMs metric.Float64Histogram
|
||||||
networkMapObjectCount metric.Int64Histogram
|
networkMapObjectCount metric.Int64Histogram
|
||||||
peerMetaUpdateCount metric.Int64Counter
|
peerMetaUpdateCount metric.Int64Counter
|
||||||
@@ -59,6 +60,13 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
nmapCounter, err := meter.Int64Counter("management.network.map.counter",
|
||||||
|
metric.WithUnit("1"),
|
||||||
|
metric.WithDescription("Number of network maps computed, labeled by resource and operation trigger"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
peerMetaUpdateCount, err := meter.Int64Counter("management.account.peer.meta.update.counter",
|
peerMetaUpdateCount, err := meter.Int64Counter("management.account.peer.meta.update.counter",
|
||||||
metric.WithUnit("1"),
|
metric.WithUnit("1"),
|
||||||
metric.WithDescription("Number of updates with new meta data from the peers"))
|
metric.WithDescription("Number of updates with new meta data from the peers"))
|
||||||
@@ -93,6 +101,7 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
|
|||||||
peerMetaUpdateCount: peerMetaUpdateCount,
|
peerMetaUpdateCount: peerMetaUpdateCount,
|
||||||
peerStatusUpdateCounter: peerStatusUpdateCounter,
|
peerStatusUpdateCounter: peerStatusUpdateCounter,
|
||||||
peerStatusUpdateDurationMs: peerStatusUpdateDurationMs,
|
peerStatusUpdateDurationMs: peerStatusUpdateDurationMs,
|
||||||
|
nmapCounter: nmapCounter,
|
||||||
}, nil
|
}, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -145,6 +154,16 @@ func (metrics *AccountManagerMetrics) CountUpdateAccountPeersTriggered(resource,
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CountNmapTriggered increments the counter for calculated network maps with resource and operation labels.
|
||||||
|
func (metrics *AccountManagerMetrics) CountNmapTriggered(resource, operation string) {
|
||||||
|
metrics.nmapCounter.Add(metrics.ctx, 1,
|
||||||
|
metric.WithAttributes(
|
||||||
|
attribute.String("resource", resource),
|
||||||
|
attribute.String("operation", operation),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// CountPeerMetUpdate counts the number of peer meta updates
|
// CountPeerMetUpdate counts the number of peer meta updates
|
||||||
func (metrics *AccountManagerMetrics) CountPeerMetUpdate() {
|
func (metrics *AccountManagerMetrics) CountPeerMetUpdate() {
|
||||||
metrics.peerMetaUpdateCount.Add(metrics.ctx, 1)
|
metrics.peerMetaUpdateCount.Add(metrics.ctx, 1)
|
||||||
|
|||||||
@@ -762,7 +762,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the initiator still has admin privileges
|
// Ensure the initiator still has admin privileges
|
||||||
if initiatorUser.HasAdminPower() && !freshInitiator.HasAdminPower() {
|
if !freshInitiator.HasAdminPower() {
|
||||||
return false, nil, nil, nil, status.Errorf(status.PermissionDenied, "initiator role was changed during request processing")
|
return false, nil, nil, nil, status.Errorf(status.PermissionDenied, "initiator role was changed during request processing")
|
||||||
}
|
}
|
||||||
initiatorUser = freshInitiator
|
initiatorUser = freshInitiator
|
||||||
@@ -906,19 +906,23 @@ func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUse
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !initiatorUser.HasAdminPower() {
|
||||||
|
return status.Errorf(status.PermissionDenied, "only admins and owners can update users")
|
||||||
|
}
|
||||||
|
|
||||||
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
|
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
|
||||||
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
|
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
|
||||||
}
|
}
|
||||||
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && update.Role != initiatorUser.Role {
|
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && update.Role != initiatorUser.Role {
|
||||||
return status.Errorf(status.PermissionDenied, "admins can't change their role")
|
return status.Errorf(status.PermissionDenied, "admins can't change their role")
|
||||||
}
|
}
|
||||||
if initiatorUser.Role == types.UserRoleAdmin && oldUser.Role == types.UserRoleOwner && update.Role != oldUser.Role {
|
if initiatorUser.Role != types.UserRoleOwner && oldUser.Role == types.UserRoleOwner && update.Role != oldUser.Role {
|
||||||
return status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user")
|
return status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user")
|
||||||
}
|
}
|
||||||
if initiatorUser.Role == types.UserRoleAdmin && oldUser.Role == types.UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() {
|
if oldUser.Role == types.UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() {
|
||||||
return status.Errorf(status.PermissionDenied, "unable to block owner user")
|
return status.Errorf(status.PermissionDenied, "unable to block owner user")
|
||||||
}
|
}
|
||||||
if initiatorUser.Role == types.UserRoleAdmin && update.Role == types.UserRoleOwner && update.Role != oldUser.Role {
|
if initiatorUser.Role != types.UserRoleOwner && update.Role == types.UserRoleOwner && update.Role != oldUser.Role {
|
||||||
return status.Errorf(status.PermissionDenied, "only owners can add owner role to other users")
|
return status.Errorf(status.PermissionDenied, "only owners can add owner role to other users")
|
||||||
}
|
}
|
||||||
if oldUser.IsServiceUser && update.Role == types.UserRoleOwner {
|
if oldUser.IsServiceUser && update.Role == types.UserRoleOwner {
|
||||||
|
|||||||
@@ -109,6 +109,22 @@ var debugStopCmd = &cobra.Command{
|
|||||||
SilenceUsage: true,
|
SilenceUsage: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var debugPerfCmd = &cobra.Command{
|
||||||
|
Use: "perf <pool-cap>",
|
||||||
|
Short: "Live-retune the tunnel buffer pool cap on all running clients",
|
||||||
|
Args: cobra.ExactArgs(1),
|
||||||
|
RunE: runDebugPerfSet,
|
||||||
|
SilenceUsage: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
var debugRuntimeCmd = &cobra.Command{
|
||||||
|
Use: "runtime",
|
||||||
|
Short: "Show runtime stats (heap, goroutines, RSS)",
|
||||||
|
Args: cobra.NoArgs,
|
||||||
|
RunE: runDebugRuntime,
|
||||||
|
SilenceUsage: true,
|
||||||
|
}
|
||||||
|
|
||||||
var debugCaptureCmd = &cobra.Command{
|
var debugCaptureCmd = &cobra.Command{
|
||||||
Use: "capture <account-id> [filter expression]",
|
Use: "capture <account-id> [filter expression]",
|
||||||
Short: "Capture packets on a client's WireGuard interface",
|
Short: "Capture packets on a client's WireGuard interface",
|
||||||
@@ -159,6 +175,8 @@ func init() {
|
|||||||
debugCmd.AddCommand(debugLogCmd)
|
debugCmd.AddCommand(debugLogCmd)
|
||||||
debugCmd.AddCommand(debugStartCmd)
|
debugCmd.AddCommand(debugStartCmd)
|
||||||
debugCmd.AddCommand(debugStopCmd)
|
debugCmd.AddCommand(debugStopCmd)
|
||||||
|
debugCmd.AddCommand(debugPerfCmd)
|
||||||
|
debugCmd.AddCommand(debugRuntimeCmd)
|
||||||
debugCmd.AddCommand(debugCaptureCmd)
|
debugCmd.AddCommand(debugCaptureCmd)
|
||||||
|
|
||||||
rootCmd.AddCommand(debugCmd)
|
rootCmd.AddCommand(debugCmd)
|
||||||
@@ -220,6 +238,18 @@ func runDebugStop(cmd *cobra.Command, args []string) error {
|
|||||||
return getDebugClient(cmd).StopClient(cmd.Context(), args[0])
|
return getDebugClient(cmd).StopClient(cmd.Context(), args[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func runDebugPerfSet(cmd *cobra.Command, args []string) error {
|
||||||
|
n, err := strconv.ParseUint(args[0], 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid value %q: %w", args[0], err)
|
||||||
|
}
|
||||||
|
return getDebugClient(cmd).PerfSet(cmd.Context(), uint32(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
func runDebugRuntime(cmd *cobra.Command, _ []string) error {
|
||||||
|
return getDebugClient(cmd).Runtime(cmd.Context())
|
||||||
|
}
|
||||||
|
|
||||||
func runDebugCapture(cmd *cobra.Command, args []string) error {
|
func runDebugCapture(cmd *cobra.Command, args []string) error {
|
||||||
duration, _ := cmd.Flags().GetDuration("duration")
|
duration, _ := cmd.Flags().GetDuration("duration")
|
||||||
forcePcap, _ := cmd.Flags().GetBool("pcap")
|
forcePcap, _ := cmd.Flags().GetBool("pcap")
|
||||||
|
|||||||
@@ -15,11 +15,22 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/embed"
|
||||||
"github.com/netbirdio/netbird/proxy"
|
"github.com/netbirdio/netbird/proxy"
|
||||||
nbacme "github.com/netbirdio/netbird/proxy/internal/acme"
|
nbacme "github.com/netbirdio/netbird/proxy/internal/acme"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// envPreallocatedBuffers caps the per-tunnel buffer pool. Zero (unset)
|
||||||
|
// keeps the upstream uncapped default.
|
||||||
|
envPreallocatedBuffers = "NB_PROXY_PREALLOCATED_BUFFERS"
|
||||||
|
// envMaxBatchSize overrides the per-tunnel batch size, which controls
|
||||||
|
// how many buffers each receive/TUN worker eagerly allocates. Zero
|
||||||
|
// (unset) keeps the platform default.
|
||||||
|
envMaxBatchSize = "NB_PROXY_MAX_BATCH_SIZE"
|
||||||
|
)
|
||||||
|
|
||||||
const DefaultManagementURL = "https://api.netbird.io:443"
|
const DefaultManagementURL = "https://api.netbird.io:443"
|
||||||
|
|
||||||
// envProxyToken is the environment variable name for the proxy access token.
|
// envProxyToken is the environment variable name for the proxy access token.
|
||||||
@@ -148,6 +159,45 @@ func runServer(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
logger.Infof("configured log level: %s", level)
|
logger.Infof("configured log level: %s", level)
|
||||||
|
|
||||||
|
var wgPool, wgBatch uint64
|
||||||
|
var perf embed.Performance
|
||||||
|
if raw := os.Getenv(envPreallocatedBuffers); raw != "" {
|
||||||
|
n, err := strconv.ParseUint(raw, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid %s %q: %w", envPreallocatedBuffers, raw, err)
|
||||||
|
}
|
||||||
|
wgPool = n
|
||||||
|
v := uint32(n)
|
||||||
|
perf.PreallocatedBuffersPerPool = &v
|
||||||
|
logger.Infof("tunnel preallocated buffers per pool: %d", n)
|
||||||
|
}
|
||||||
|
if raw := os.Getenv(envMaxBatchSize); raw != "" {
|
||||||
|
n, err := strconv.ParseUint(raw, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid %s %q: %w", envMaxBatchSize, raw, err)
|
||||||
|
}
|
||||||
|
wgBatch = n
|
||||||
|
v := uint32(n)
|
||||||
|
perf.MaxBatchSize = &v
|
||||||
|
logger.Infof("tunnel max batch size override: %d", n)
|
||||||
|
}
|
||||||
|
if wgPool > 0 {
|
||||||
|
// Each bind recv goroutine (IPv4 + IPv6 + ICE relay) plus
|
||||||
|
// RoutineReadFromTUN eagerly reserves `batch` message buffers for
|
||||||
|
// the lifetime of the Device. A pool cap below that floor blocks
|
||||||
|
// the receive pipeline at startup.
|
||||||
|
batch := wgBatch
|
||||||
|
if batch == 0 {
|
||||||
|
batch = 128
|
||||||
|
}
|
||||||
|
const recvGoroutines = 4
|
||||||
|
floor := batch * recvGoroutines
|
||||||
|
if wgPool < floor {
|
||||||
|
logger.Warnf("%s=%d is below the eager-allocation floor (~%d for batch=%d); startup may deadlock",
|
||||||
|
envPreallocatedBuffers, wgPool, floor, batch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
switch forwardedProto {
|
switch forwardedProto {
|
||||||
case "auto", "http", "https":
|
case "auto", "http", "https":
|
||||||
default:
|
default:
|
||||||
@@ -188,6 +238,7 @@ func runServer(cmd *cobra.Command, args []string) error {
|
|||||||
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
|
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
|
||||||
WildcardCertDir: wildcardCertDir,
|
WildcardCertDir: wildcardCertDir,
|
||||||
WireguardPort: wgPort,
|
WireguardPort: wgPort,
|
||||||
|
Performance: perf,
|
||||||
ProxyProtocol: proxyProtocol,
|
ProxyProtocol: proxyProtocol,
|
||||||
PreSharedKey: preSharedKey,
|
PreSharedKey: preSharedKey,
|
||||||
SupportsCustomPorts: supportsCustomPorts,
|
SupportsCustomPorts: supportsCustomPorts,
|
||||||
|
|||||||
@@ -333,6 +333,63 @@ func (c *Client) printLogLevelResult(data map[string]any) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PerfSet live-retunes the tunnel buffer pool cap on all running embedded
|
||||||
|
// clients. Batch size is not live-tunable; configure it at proxy startup.
|
||||||
|
func (c *Client) PerfSet(ctx context.Context, value uint32) error {
|
||||||
|
path := fmt.Sprintf("/debug/perf?value=%d", value)
|
||||||
|
return c.fetchAndPrint(ctx, path, c.printPerfSet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) printPerfSet(data map[string]any) {
|
||||||
|
if errMsg, ok := data["error"].(string); ok && errMsg != "" {
|
||||||
|
c.printError(data)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
val, _ := data["value"].(float64)
|
||||||
|
applied, _ := data["applied"].(float64)
|
||||||
|
_, _ = fmt.Fprintf(c.out, "Pool cap set to: %d\n", uint32(val))
|
||||||
|
_, _ = fmt.Fprintf(c.out, "Applied to %d live clients\n", int(applied))
|
||||||
|
if failed, ok := data["failed"].(map[string]any); ok && len(failed) > 0 {
|
||||||
|
_, _ = fmt.Fprintln(c.out, "Failed:")
|
||||||
|
for k, v := range failed {
|
||||||
|
_, _ = fmt.Fprintf(c.out, " %s: %v\n", k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runtime fetches runtime stats (heap, goroutines, RSS).
|
||||||
|
func (c *Client) Runtime(ctx context.Context) error {
|
||||||
|
return c.fetchAndPrint(ctx, "/debug/runtime", c.printRuntime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) printRuntime(data map[string]any) {
|
||||||
|
i := func(k string) uint64 {
|
||||||
|
v, _ := data[k].(float64)
|
||||||
|
return uint64(v)
|
||||||
|
}
|
||||||
|
mb := func(n uint64) string { return fmt.Sprintf("%.1f MB", float64(n)/(1<<20)) }
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"])
|
||||||
|
_, _ = fmt.Fprintf(c.out, "Go: %v on %d CPU (GOMAXPROCS=%d)\n", data["go_version"], uint32(i("num_cpu")), uint32(i("gomaxprocs")))
|
||||||
|
_, _ = fmt.Fprintf(c.out, "Goroutines: %d\n", i("goroutines"))
|
||||||
|
_, _ = fmt.Fprintf(c.out, "Live objects: %d\n", i("live_objects"))
|
||||||
|
_, _ = fmt.Fprintf(c.out, "GC: %d cycles, %v pause total\n", i("num_gc"), time.Duration(i("pause_total_ns")))
|
||||||
|
_, _ = fmt.Fprintln(c.out, "Heap:")
|
||||||
|
_, _ = fmt.Fprintf(c.out, " alloc: %s\n", mb(i("heap_alloc")))
|
||||||
|
_, _ = fmt.Fprintf(c.out, " in-use: %s\n", mb(i("heap_inuse")))
|
||||||
|
_, _ = fmt.Fprintf(c.out, " idle: %s\n", mb(i("heap_idle")))
|
||||||
|
_, _ = fmt.Fprintf(c.out, " released: %s\n", mb(i("heap_released")))
|
||||||
|
_, _ = fmt.Fprintf(c.out, " sys: %s\n", mb(i("heap_sys")))
|
||||||
|
_, _ = fmt.Fprintf(c.out, "Total sys: %s\n", mb(i("sys")))
|
||||||
|
if _, ok := data["vm_rss"]; ok {
|
||||||
|
_, _ = fmt.Fprintln(c.out, "Process:")
|
||||||
|
_, _ = fmt.Fprintf(c.out, " VmRSS: %s\n", mb(i("vm_rss")))
|
||||||
|
_, _ = fmt.Fprintf(c.out, " VmSize: %s\n", mb(i("vm_size")))
|
||||||
|
_, _ = fmt.Fprintf(c.out, " VmData: %s\n", mb(i("vm_data")))
|
||||||
|
}
|
||||||
|
_, _ = fmt.Fprintf(c.out, "Clients: %d (%d started)\n", i("clients"), i("started"))
|
||||||
|
}
|
||||||
|
|
||||||
// StartClient starts a specific client.
|
// StartClient starts a specific client.
|
||||||
func (c *Client) StartClient(ctx context.Context, accountID string) error {
|
func (c *Client) StartClient(ctx context.Context, accountID string) error {
|
||||||
path := "/debug/clients/" + url.PathEscape(accountID) + "/start"
|
path := "/debug/clients/" + url.PathEscape(accountID) + "/start"
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ import (
|
|||||||
"maps"
|
"maps"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -59,6 +61,7 @@ func sortedAccountIDs(m map[types.AccountID]roundtrip.ClientDebugInfo) []types.A
|
|||||||
type clientProvider interface {
|
type clientProvider interface {
|
||||||
GetClient(accountID types.AccountID) (*nbembed.Client, bool)
|
GetClient(accountID types.AccountID) (*nbembed.Client, bool)
|
||||||
ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo
|
ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo
|
||||||
|
ListClientsForStartup() map[types.AccountID]*nbembed.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
// InboundListenerInfo describes a per-account inbound listener as
|
// InboundListenerInfo describes a per-account inbound listener as
|
||||||
@@ -165,6 +168,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
h.handleListClients(w, r, wantJSON)
|
h.handleListClients(w, r, wantJSON)
|
||||||
case "/debug/health":
|
case "/debug/health":
|
||||||
h.handleHealth(w, r, wantJSON)
|
h.handleHealth(w, r, wantJSON)
|
||||||
|
case "/debug/perf":
|
||||||
|
h.handlePerf(w, r)
|
||||||
|
case "/debug/runtime":
|
||||||
|
h.handleRuntime(w, r)
|
||||||
default:
|
default:
|
||||||
if h.handleClientRoutes(w, r, path, wantJSON) {
|
if h.handleClientRoutes(w, r, path, wantJSON) {
|
||||||
return
|
return
|
||||||
@@ -258,10 +265,10 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
|
|||||||
}
|
}
|
||||||
|
|
||||||
if wantJSON {
|
if wantJSON {
|
||||||
clientsJSON := make([]map[string]interface{}, 0, len(clients))
|
clientsJSON := make([]map[string]any, 0, len(clients))
|
||||||
for _, id := range sortedIDs {
|
for _, id := range sortedIDs {
|
||||||
info := clients[id]
|
info := clients[id]
|
||||||
clientsJSON = append(clientsJSON, map[string]interface{}{
|
clientsJSON = append(clientsJSON, map[string]any{
|
||||||
"account_id": info.AccountID,
|
"account_id": info.AccountID,
|
||||||
"service_count": info.ServiceCount,
|
"service_count": info.ServiceCount,
|
||||||
"service_keys": info.ServiceKeys,
|
"service_keys": info.ServiceKeys,
|
||||||
@@ -270,7 +277,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
|
|||||||
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
|
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
resp := map[string]interface{}{
|
resp := map[string]any{
|
||||||
"version": version.NetbirdVersion(),
|
"version": version.NetbirdVersion(),
|
||||||
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
||||||
"client_count": len(clients),
|
"client_count": len(clients),
|
||||||
@@ -352,10 +359,10 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
|
|||||||
if h.inbound != nil {
|
if h.inbound != nil {
|
||||||
inboundAll = h.inbound.InboundListeners()
|
inboundAll = h.inbound.InboundListeners()
|
||||||
}
|
}
|
||||||
clientsJSON := make([]map[string]interface{}, 0, len(clients))
|
clientsJSON := make([]map[string]any, 0, len(clients))
|
||||||
for _, id := range sortedIDs {
|
for _, id := range sortedIDs {
|
||||||
info := clients[id]
|
info := clients[id]
|
||||||
row := map[string]interface{}{
|
row := map[string]any{
|
||||||
"account_id": info.AccountID,
|
"account_id": info.AccountID,
|
||||||
"service_count": info.ServiceCount,
|
"service_count": info.ServiceCount,
|
||||||
"service_keys": info.ServiceKeys,
|
"service_keys": info.ServiceKeys,
|
||||||
@@ -368,7 +375,7 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
|
|||||||
}
|
}
|
||||||
clientsJSON = append(clientsJSON, row)
|
clientsJSON = append(clientsJSON, row)
|
||||||
}
|
}
|
||||||
resp := map[string]interface{}{
|
resp := map[string]any{
|
||||||
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
||||||
"client_count": len(clients),
|
"client_count": len(clients),
|
||||||
"clients": clientsJSON,
|
"clients": clientsJSON,
|
||||||
@@ -458,7 +465,7 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc
|
|||||||
})
|
})
|
||||||
|
|
||||||
if wantJSON {
|
if wantJSON {
|
||||||
resp := map[string]interface{}{
|
resp := map[string]any{
|
||||||
"account_id": accountID,
|
"account_id": accountID,
|
||||||
"status": overview.FullDetailSummary(),
|
"status": overview.FullDetailSummary(),
|
||||||
}
|
}
|
||||||
@@ -557,20 +564,20 @@ func (h *Handler) handleClientTools(w http.ResponseWriter, _ *http.Request, acco
|
|||||||
func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||||
client, ok := h.provider.GetClient(accountID)
|
client, ok := h.provider.GetClient(accountID)
|
||||||
if !ok {
|
if !ok {
|
||||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
h.writeJSON(w, map[string]any{"error": "client not found"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
host := r.URL.Query().Get("host")
|
host := r.URL.Query().Get("host")
|
||||||
portStr := r.URL.Query().Get("port")
|
portStr := r.URL.Query().Get("port")
|
||||||
if host == "" || portStr == "" {
|
if host == "" || portStr == "" {
|
||||||
h.writeJSON(w, map[string]interface{}{"error": "host and port parameters required"})
|
h.writeJSON(w, map[string]any{"error": "host and port parameters required"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
port, err := strconv.Atoi(portStr)
|
port, err := strconv.Atoi(portStr)
|
||||||
if err != nil || port < 1 || port > 65535 {
|
if err != nil || port < 1 || port > 65535 {
|
||||||
h.writeJSON(w, map[string]interface{}{"error": "invalid port"})
|
h.writeJSON(w, map[string]any{"error": "invalid port"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -594,7 +601,7 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
|
|||||||
|
|
||||||
conn, err := client.Dial(ctx, network, address)
|
conn, err := client.Dial(ctx, network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.writeJSON(w, map[string]interface{}{
|
h.writeJSON(w, map[string]any{
|
||||||
"success": false,
|
"success": false,
|
||||||
"host": host,
|
"host": host,
|
||||||
"port": port,
|
"port": port,
|
||||||
@@ -609,39 +616,38 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
|
|||||||
}
|
}
|
||||||
|
|
||||||
latency := time.Since(start)
|
latency := time.Since(start)
|
||||||
resp := map[string]interface{}{
|
h.writeJSON(w, map[string]any{
|
||||||
"success": true,
|
"success": true,
|
||||||
"host": host,
|
"host": host,
|
||||||
"port": port,
|
"port": port,
|
||||||
"remote": remote,
|
"remote": remote,
|
||||||
"latency_ms": latency.Milliseconds(),
|
"latency_ms": latency.Milliseconds(),
|
||||||
"latency": formatDuration(latency),
|
"latency": formatDuration(latency),
|
||||||
}
|
})
|
||||||
h.writeJSON(w, resp)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||||
client, ok := h.provider.GetClient(accountID)
|
client, ok := h.provider.GetClient(accountID)
|
||||||
if !ok {
|
if !ok {
|
||||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
h.writeJSON(w, map[string]any{"error": "client not found"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
level := r.URL.Query().Get("level")
|
level := r.URL.Query().Get("level")
|
||||||
if level == "" {
|
if level == "" {
|
||||||
h.writeJSON(w, map[string]interface{}{"error": "level parameter required (trace, debug, info, warn, error)"})
|
h.writeJSON(w, map[string]any{"error": "level parameter required (trace, debug, info, warn, error)"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := client.SetLogLevel(level); err != nil {
|
if err := client.SetLogLevel(level); err != nil {
|
||||||
h.writeJSON(w, map[string]interface{}{
|
h.writeJSON(w, map[string]any{
|
||||||
"success": false,
|
"success": false,
|
||||||
"error": err.Error(),
|
"error": err.Error(),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.writeJSON(w, map[string]interface{}{
|
h.writeJSON(w, map[string]any{
|
||||||
"success": true,
|
"success": true,
|
||||||
"level": level,
|
"level": level,
|
||||||
})
|
})
|
||||||
@@ -652,7 +658,7 @@ const clientActionTimeout = 30 * time.Second
|
|||||||
func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||||
client, ok := h.provider.GetClient(accountID)
|
client, ok := h.provider.GetClient(accountID)
|
||||||
if !ok {
|
if !ok {
|
||||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
h.writeJSON(w, map[string]any{"error": "client not found"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -660,14 +666,14 @@ func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, acco
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if err := client.Start(ctx); err != nil {
|
if err := client.Start(ctx); err != nil {
|
||||||
h.writeJSON(w, map[string]interface{}{
|
h.writeJSON(w, map[string]any{
|
||||||
"success": false,
|
"success": false,
|
||||||
"error": err.Error(),
|
"error": err.Error(),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.writeJSON(w, map[string]interface{}{
|
h.writeJSON(w, map[string]any{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "client started",
|
"message": "client started",
|
||||||
})
|
})
|
||||||
@@ -676,7 +682,7 @@ func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, acco
|
|||||||
func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||||
client, ok := h.provider.GetClient(accountID)
|
client, ok := h.provider.GetClient(accountID)
|
||||||
if !ok {
|
if !ok {
|
||||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
h.writeJSON(w, map[string]any{"error": "client not found"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -684,19 +690,125 @@ func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accou
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if err := client.Stop(ctx); err != nil {
|
if err := client.Stop(ctx); err != nil {
|
||||||
h.writeJSON(w, map[string]interface{}{
|
h.writeJSON(w, map[string]any{
|
||||||
"success": false,
|
"success": false,
|
||||||
"error": err.Error(),
|
"error": err.Error(),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.writeJSON(w, map[string]interface{}{
|
h.writeJSON(w, map[string]any{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "client stopped",
|
"message": "client stopped",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handlePerf(w http.ResponseWriter, r *http.Request) {
|
||||||
|
raw := r.URL.Query().Get("value")
|
||||||
|
if raw == "" {
|
||||||
|
http.Error(w, "value parameter is required", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n, err := strconv.ParseUint(raw, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("invalid value %q: %v", raw, err), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
capN := uint32(n)
|
||||||
|
applied := 0
|
||||||
|
failed := map[string]string{}
|
||||||
|
for accountID, client := range h.provider.ListClientsForStartup() {
|
||||||
|
if err := client.SetPerformance(nbembed.Performance{PreallocatedBuffersPerPool: &capN}); err != nil {
|
||||||
|
failed[string(accountID)] = err.Error()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
applied++
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := map[string]any{
|
||||||
|
"success": true,
|
||||||
|
"value": capN,
|
||||||
|
"applied": applied,
|
||||||
|
}
|
||||||
|
if len(failed) > 0 {
|
||||||
|
resp["failed"] = failed
|
||||||
|
}
|
||||||
|
h.writeJSON(w, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleRuntime returns cheap runtime and process stats. Safe to hit on a
|
||||||
|
// running proxy; does not read pprof profiles.
|
||||||
|
func (h *Handler) handleRuntime(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
var m runtime.MemStats
|
||||||
|
runtime.ReadMemStats(&m)
|
||||||
|
|
||||||
|
clients := h.provider.ListClientsForDebug()
|
||||||
|
started := 0
|
||||||
|
for _, c := range clients {
|
||||||
|
if c.HasClient {
|
||||||
|
started++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := map[string]any{
|
||||||
|
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
||||||
|
"goroutines": runtime.NumGoroutine(),
|
||||||
|
"num_cpu": runtime.NumCPU(),
|
||||||
|
"gomaxprocs": runtime.GOMAXPROCS(0),
|
||||||
|
"go_version": runtime.Version(),
|
||||||
|
"heap_alloc": m.HeapAlloc,
|
||||||
|
"heap_inuse": m.HeapInuse,
|
||||||
|
"heap_idle": m.HeapIdle,
|
||||||
|
"heap_released": m.HeapReleased,
|
||||||
|
"heap_sys": m.HeapSys,
|
||||||
|
"sys": m.Sys,
|
||||||
|
"live_objects": m.Mallocs - m.Frees,
|
||||||
|
"num_gc": m.NumGC,
|
||||||
|
"pause_total_ns": m.PauseTotalNs,
|
||||||
|
"clients": len(clients),
|
||||||
|
"started": started,
|
||||||
|
}
|
||||||
|
|
||||||
|
if proc := readProcStatus(); proc != nil {
|
||||||
|
resp["vm_rss"] = proc["VmRSS"]
|
||||||
|
resp["vm_size"] = proc["VmSize"]
|
||||||
|
resp["vm_data"] = proc["VmData"]
|
||||||
|
}
|
||||||
|
|
||||||
|
h.writeJSON(w, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// readProcStatus parses /proc/self/status on Linux and returns size fields
|
||||||
|
// in bytes. Returns nil on non-Linux or read failure.
|
||||||
|
func readProcStatus() map[string]uint64 {
|
||||||
|
raw, err := os.ReadFile("/proc/self/status")
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := map[string]uint64{}
|
||||||
|
for _, line := range strings.Split(string(raw), "\n") {
|
||||||
|
k, v, ok := strings.Cut(line, ":")
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if k != "VmRSS" && k != "VmSize" && k != "VmData" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fields := strings.Fields(v)
|
||||||
|
if len(fields) < 1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
n, err := strconv.ParseUint(fields[0], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Values are reported in kB.
|
||||||
|
out[k] = n * 1024
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
const maxCaptureDuration = 30 * time.Minute
|
const maxCaptureDuration = 30 * time.Minute
|
||||||
|
|
||||||
// handleCapture streams a pcap or text packet capture for the given client.
|
// handleCapture streams a pcap or text packet capture for the given client.
|
||||||
@@ -825,7 +937,7 @@ func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON
|
|||||||
h.writeJSON(w, resp)
|
h.writeJSON(w, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interface{}) {
|
func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data any) {
|
||||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
tmpl := h.getTemplates()
|
tmpl := h.getTemplates()
|
||||||
if tmpl == nil {
|
if tmpl == nil {
|
||||||
@@ -838,7 +950,7 @@ func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interf
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) writeJSON(w http.ResponseWriter, v interface{}) {
|
func (h *Handler) writeJSON(w http.ResponseWriter, v any) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
enc := json.NewEncoder(w)
|
enc := json.NewEncoder(w)
|
||||||
enc.SetIndent("", " ")
|
enc.SetIndent("", " ")
|
||||||
|
|||||||
@@ -131,6 +131,7 @@ type ClientConfig struct {
|
|||||||
MgmtAddr string
|
MgmtAddr string
|
||||||
WGPort uint16
|
WGPort uint16
|
||||||
PreSharedKey string
|
PreSharedKey string
|
||||||
|
Performance embed.Performance
|
||||||
// BlockInbound mirrors embed.Options.BlockInbound. Set to true on the
|
// BlockInbound mirrors embed.Options.BlockInbound. Set to true on the
|
||||||
// standalone proxy where the embedded client never accepts inbound;
|
// standalone proxy where the embedded client never accepts inbound;
|
||||||
// set to false on the private/embedded proxy so the engine creates
|
// set to false on the private/embedded proxy so the engine creates
|
||||||
@@ -306,7 +307,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
|||||||
ManagementURL: n.clientCfg.MgmtAddr,
|
ManagementURL: n.clientCfg.MgmtAddr,
|
||||||
PrivateKey: privateKey.String(),
|
PrivateKey: privateKey.String(),
|
||||||
LogLevel: log.WarnLevel.String(),
|
LogLevel: log.WarnLevel.String(),
|
||||||
BlockInbound: n.clientCfg.BlockInbound,
|
BlockInbound: n.clientCfg.BlockInbound,
|
||||||
// The embedded proxy peer must never be a stepping stone into
|
// The embedded proxy peer must never be a stepping stone into
|
||||||
// the proxy host's LAN: it only exists to reach NetBird mesh
|
// the proxy host's LAN: it only exists to reach NetBird mesh
|
||||||
// targets or, when direct_upstream is set, the host network
|
// targets or, when direct_upstream is set, the host network
|
||||||
@@ -315,6 +316,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
|||||||
BlockLANAccess: true,
|
BlockLANAccess: true,
|
||||||
WireguardPort: &wgPort,
|
WireguardPort: &wgPort,
|
||||||
PreSharedKey: n.clientCfg.PreSharedKey,
|
PreSharedKey: n.clientCfg.PreSharedKey,
|
||||||
|
Performance: n.clientCfg.Performance,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create netbird client: %w", err)
|
return nil, fmt.Errorf("create netbird client: %w", err)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/embed"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/acme"
|
"github.com/netbirdio/netbird/proxy/internal/acme"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -89,6 +90,10 @@ type Config struct {
|
|||||||
// PreSharedKey is the WireGuard pre-shared key used between the
|
// PreSharedKey is the WireGuard pre-shared key used between the
|
||||||
// proxy's embedded clients and peers.
|
// proxy's embedded clients and peers.
|
||||||
PreSharedKey string
|
PreSharedKey string
|
||||||
|
// Performance configures the tunnel pool/batch sizes for every
|
||||||
|
// embedded client this proxy creates. Zero values fall back to
|
||||||
|
// upstream defaults.
|
||||||
|
Performance embed.Performance
|
||||||
|
|
||||||
// SupportsCustomPorts indicates whether the proxy can bind arbitrary
|
// SupportsCustomPorts indicates whether the proxy can bind arbitrary
|
||||||
// ports for TCP/UDP/TLS services.
|
// ports for TCP/UDP/TLS services.
|
||||||
@@ -148,6 +153,7 @@ func New(cfg Config) *Server {
|
|||||||
WireguardPort: cfg.WireguardPort,
|
WireguardPort: cfg.WireguardPort,
|
||||||
ProxyProtocol: cfg.ProxyProtocol,
|
ProxyProtocol: cfg.ProxyProtocol,
|
||||||
PreSharedKey: cfg.PreSharedKey,
|
PreSharedKey: cfg.PreSharedKey,
|
||||||
|
Performance: cfg.Performance,
|
||||||
SupportsCustomPorts: cfg.SupportsCustomPorts,
|
SupportsCustomPorts: cfg.SupportsCustomPorts,
|
||||||
RequireSubdomain: cfg.RequireSubdomain,
|
RequireSubdomain: cfg.RequireSubdomain,
|
||||||
Private: cfg.Private,
|
Private: cfg.Private,
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ import (
|
|||||||
goproto "google.golang.org/protobuf/proto"
|
goproto "google.golang.org/protobuf/proto"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/embed"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/acme"
|
"github.com/netbirdio/netbird/proxy/internal/acme"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||||
@@ -185,6 +186,9 @@ type Server struct {
|
|||||||
// single-account deployments; multiple accounts will fail to bind
|
// single-account deployments; multiple accounts will fail to bind
|
||||||
// the same port.
|
// the same port.
|
||||||
WireguardPort uint16
|
WireguardPort uint16
|
||||||
|
// Performance configures the tunnel pool/batch sizes for every
|
||||||
|
// embedded client this proxy spawns.
|
||||||
|
Performance embed.Performance
|
||||||
// ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners.
|
// ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners.
|
||||||
// When enabled, the real client IP is extracted from the PROXY header
|
// When enabled, the real client IP is extracted from the PROXY header
|
||||||
// sent by upstream L4 proxies that support PROXY protocol.
|
// sent by upstream L4 proxies that support PROXY protocol.
|
||||||
@@ -333,6 +337,8 @@ func (s *Server) Start(ctx context.Context) error {
|
|||||||
s.runCancel = runCancel
|
s.runCancel = runCancel
|
||||||
|
|
||||||
s.initNetBirdClient()
|
s.initNetBirdClient()
|
||||||
|
// Create health checker before the mapping worker so it can track
|
||||||
|
// management connectivity from the first stream connection.
|
||||||
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
|
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
|
||||||
|
|
||||||
s.crowdsecRegistry = crowdsec.NewRegistry(s.CrowdSecAPIURL, s.CrowdSecAPIKey, log.NewEntry(s.Logger))
|
s.crowdsecRegistry = crowdsec.NewRegistry(s.CrowdSecAPIURL, s.CrowdSecAPIKey, log.NewEntry(s.Logger))
|
||||||
@@ -529,6 +535,7 @@ func (s *Server) initNetBirdClient() {
|
|||||||
MgmtAddr: s.ManagementAddress,
|
MgmtAddr: s.ManagementAddress,
|
||||||
WGPort: s.WireguardPort,
|
WGPort: s.WireguardPort,
|
||||||
PreSharedKey: s.PreSharedKey,
|
PreSharedKey: s.PreSharedKey,
|
||||||
|
Performance: s.Performance,
|
||||||
// On --private the embedded client serves per-account inbound
|
// On --private the embedded client serves per-account inbound
|
||||||
// listeners and must apply management's ACL: keep BlockInbound off
|
// listeners and must apply management's ACL: keep BlockInbound off
|
||||||
// so the engine creates the ACL manager. On the standalone proxy
|
// so the engine creates the ACL manager. On the standalone proxy
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import (
|
|||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||||
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
@@ -103,7 +103,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ia, _ := integrations.NewIntegratedValidator(ctx, peersManger, settingsManagerMock, eventStore, cacheStore)
|
ia, _ := validator.NewIntegratedValidator(ctx, peersManger, settingsManagerMock, eventStore, cacheStore)
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
50
util/log.go
50
util/log.go
@@ -1,15 +1,16 @@
|
|||||||
package util
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/DeRuina/timberjack"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc/grpclog"
|
"google.golang.org/grpc/grpclog"
|
||||||
"gopkg.in/natefinch/lumberjack.v2"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
)
|
)
|
||||||
@@ -59,7 +60,12 @@ func InitLogger(logger *log.Logger, logLevel string, logs ...string) error {
|
|||||||
case "":
|
case "":
|
||||||
logger.Warnf("empty log path received: %#v", logPath)
|
logger.Warnf("empty log path received: %#v", logPath)
|
||||||
default:
|
default:
|
||||||
writers = append(writers, newRotatedOutput(logPath))
|
writer, err := setupLogFile(logPath, isRotationDisabled(logger))
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("failed setting up log file: %s, %s", logPath, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
writers = append(writers, writer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,17 +100,43 @@ func FindFirstLogPath(logs []string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isRotationDisabled(logger *log.Logger) bool {
|
||||||
|
v, _ := os.LookupEnv("NB_LOG_DISABLE_ROTATION")
|
||||||
|
disabled, _ := strconv.ParseBool(v)
|
||||||
|
if disabled {
|
||||||
|
logger.Warnf("log rotation is disabled by env flag")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
conflict, configPath := FindFirstLogrotateConflict()
|
||||||
|
if conflict {
|
||||||
|
logger.Warnf("log rotation conflict detected in: %#v, rotation is disabled", configPath)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupLogFile(logPath string, disableRotation bool) (io.Writer, error) {
|
||||||
|
if disableRotation {
|
||||||
|
file, err := os.OpenFile(logPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed opening log file: %s", err)
|
||||||
|
}
|
||||||
|
return file, nil
|
||||||
|
}
|
||||||
|
return newRotatedOutput(logPath), nil
|
||||||
|
}
|
||||||
|
|
||||||
func newRotatedOutput(logPath string) io.Writer {
|
func newRotatedOutput(logPath string) io.Writer {
|
||||||
maxLogSize := getLogMaxSize()
|
maxLogSize := getLogMaxSize()
|
||||||
lumberjackLogger := &lumberjack.Logger{
|
timberjackLogger := &timberjack.Logger{
|
||||||
// Log file absolute path, os agnostic
|
// Log file absolute path, os agnostic
|
||||||
Filename: filepath.ToSlash(logPath),
|
Filename: filepath.ToSlash(logPath),
|
||||||
MaxSize: maxLogSize, // MB
|
MaxSize: maxLogSize, // MB
|
||||||
MaxBackups: 10,
|
MaxBackups: 10,
|
||||||
MaxAge: 30, // days
|
MaxAge: 30, // days
|
||||||
Compress: true,
|
Compression: "gzip",
|
||||||
}
|
}
|
||||||
return lumberjackLogger
|
return timberjackLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
func setGRPCLibLogger(logger *log.Logger) {
|
func setGRPCLibLogger(logger *log.Logger) {
|
||||||
|
|||||||
96
util/log_test.go
Normal file
96
util/log_test.go
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestSetupLogFile_RotatesOnSize drives >MaxSize bytes through the writer
|
||||||
|
// returned by setupLogFile and asserts a backup file appears.
|
||||||
|
func TestSetupLogFile_RotatesOnSize(t *testing.T) {
|
||||||
|
t.Setenv("NB_LOG_MAX_SIZE_MB", "1")
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
logPath := filepath.Join(dir, "netbird.log")
|
||||||
|
|
||||||
|
w, err := setupLogFile(logPath, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if c, ok := w.(io.Closer); ok {
|
||||||
|
_ = c.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
chunk := []byte(strings.Repeat("x", 64*1024) + "\n")
|
||||||
|
for range 20 {
|
||||||
|
_, err := w.Write(chunk)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := os.Stat(logPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Less(t, info.Size(), int64(1<<20),
|
||||||
|
"active log should be < 1 MB after rotation, got %d", info.Size())
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
entries, _ := os.ReadDir(dir)
|
||||||
|
for _, e := range entries {
|
||||||
|
name := e.Name()
|
||||||
|
if name == filepath.Base(logPath) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(name, "netbird-") && strings.HasSuffix(name, ".log.gz") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}, 5*time.Second, 50*time.Millisecond, "expected a rotated backup file in %s", dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSetupLogFile_RotationDisabled verifies that with rotation off, the file
|
||||||
|
// grows past MaxSize and no backups are created.
|
||||||
|
func TestSetupLogFile_RotationDisabled(t *testing.T) {
|
||||||
|
t.Setenv("NB_LOG_MAX_SIZE_MB", "1")
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
logPath := filepath.Join(dir, "netbird.log")
|
||||||
|
|
||||||
|
w, err := setupLogFile(logPath, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
f, ok := w.(*os.File)
|
||||||
|
require.True(t, ok, "expected plain *os.File when rotation is disabled, got %T", w)
|
||||||
|
t.Cleanup(func() { _ = f.Close() })
|
||||||
|
|
||||||
|
chunk := []byte(strings.Repeat("y", 64*1024) + "\n")
|
||||||
|
for range 20 {
|
||||||
|
_, err := w.Write(chunk)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := os.Stat(logPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.GreaterOrEqual(t, info.Size(), int64(1<<20),
|
||||||
|
"file should exceed MaxSize when rotation is disabled, got %d", info.Size())
|
||||||
|
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, entries, 1, "no backup files should exist when rotation is disabled, got %v", entries)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIsRotationDisabled_EnvFlag covers the NB_LOG_DISABLE_ROTATION env path.
|
||||||
|
// The logrotate-conflict branch is exercised separately on linux.
|
||||||
|
func TestIsRotationDisabled_EnvFlag(t *testing.T) {
|
||||||
|
logger := log.New()
|
||||||
|
logger.SetOutput(io.Discard)
|
||||||
|
|
||||||
|
t.Setenv("NB_LOG_DISABLE_ROTATION", "true")
|
||||||
|
require.True(t, isRotationDisabled(logger))
|
||||||
|
}
|
||||||
93
util/logrotate_linux.go
Normal file
93
util/logrotate_linux.go
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"errors"
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultLogrotateConfPath = "/etc/logrotate.conf"
|
||||||
|
defaultLogrotateConfDir = "/etc/logrotate.d"
|
||||||
|
netbirdString = "netbird"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FindLogrotateConflicts scans the standard logrotate locations for
|
||||||
|
// indications of conflict with netbird. It returns true and the config file
|
||||||
|
// path if a conflict was found.
|
||||||
|
func FindFirstLogrotateConflict() (bool, string) {
|
||||||
|
return findFirstLogrotateConflictIn(defaultLogrotateConfPath, defaultLogrotateConfDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
func findFirstLogrotateConflictIn(confPath, confDir string) (bool, string) {
|
||||||
|
for _, f := range listLogrotateConfigs(confPath, confDir) {
|
||||||
|
present, err := scanLogrotateFile(f, netbirdString)
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, fs.ErrNotExist) {
|
||||||
|
log.Debugf("scan %s: %v", f, err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if present {
|
||||||
|
return present, f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// listLogrotateConfigs returns all config files for logrotate.
|
||||||
|
func listLogrotateConfigs(confPath, confDir string) []string {
|
||||||
|
files := []string{confPath}
|
||||||
|
entries, err := os.ReadDir(confDir)
|
||||||
|
if err != nil {
|
||||||
|
return files
|
||||||
|
}
|
||||||
|
for _, e := range entries {
|
||||||
|
if e.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
files = append(files, filepath.Join(confDir, e.Name()))
|
||||||
|
}
|
||||||
|
return files
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanLogrotateFile reads a config and reports if a non-comment line
|
||||||
|
// contains the given substring.
|
||||||
|
func scanLogrotateFile(path string, substring string) (bool, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := f.Close(); err != nil {
|
||||||
|
log.Debugf("close %s: %v", path, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(f)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := strings.TrimSpace(stripLogrotateComment(scanner.Text()))
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.Contains(line, substring) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripLogrotateComment(line string) string {
|
||||||
|
before, _, _ := strings.Cut(line, "#")
|
||||||
|
return before
|
||||||
|
}
|
||||||
95
util/logrotate_linux_test.go
Normal file
95
util/logrotate_linux_test.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFindFirstLogrotateConflict(t *testing.T) {
|
||||||
|
t.Run("conflict in confDir", func(t *testing.T) {
|
||||||
|
confPath, confDir := newLogrotateLayout(t)
|
||||||
|
conflictPath := filepath.Join(confDir, "netbird")
|
||||||
|
writeLogrotateConfig(t, conflictPath, `/var/log/netbird/*.log {
|
||||||
|
daily
|
||||||
|
rotate 7
|
||||||
|
}`)
|
||||||
|
writeLogrotateConfig(t, filepath.Join(confDir, "nginx"), `/var/log/nginx/*.log { daily }`)
|
||||||
|
|
||||||
|
got, path := findFirstLogrotateConflictIn(confPath, confDir)
|
||||||
|
require.True(t, got)
|
||||||
|
require.Equal(t, conflictPath, path)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("conflict in main conf file", func(t *testing.T) {
|
||||||
|
confPath, confDir := newLogrotateLayout(t)
|
||||||
|
writeLogrotateConfig(t, confPath, `weekly
|
||||||
|
rotate 4
|
||||||
|
include /etc/logrotate.d
|
||||||
|
/var/log/netbird/client.log { rotate 5 }`)
|
||||||
|
|
||||||
|
got, path := findFirstLogrotateConflictIn(confPath, confDir)
|
||||||
|
require.True(t, got)
|
||||||
|
require.Equal(t, confPath, path)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no conflict when netbird is absent", func(t *testing.T) {
|
||||||
|
confPath, confDir := newLogrotateLayout(t)
|
||||||
|
writeLogrotateConfig(t, filepath.Join(confDir, "nginx"), `/var/log/nginx/*.log { daily }`)
|
||||||
|
writeLogrotateConfig(t, filepath.Join(confDir, "syslog"), `/var/log/syslog { weekly }`)
|
||||||
|
|
||||||
|
got, path := findFirstLogrotateConflictIn(confPath, confDir)
|
||||||
|
require.False(t, got)
|
||||||
|
require.Empty(t, path)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("commented-out netbird line is ignored", func(t *testing.T) {
|
||||||
|
confPath, confDir := newLogrotateLayout(t)
|
||||||
|
writeLogrotateConfig(t, filepath.Join(confDir, "misc"), `# /var/log/netbird/*.log { daily }
|
||||||
|
/var/log/other.log { weekly }`)
|
||||||
|
|
||||||
|
got, path := findFirstLogrotateConflictIn(confPath, confDir)
|
||||||
|
require.False(t, got)
|
||||||
|
require.Empty(t, path)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("subdirectories in confDir are ignored", func(t *testing.T) {
|
||||||
|
confPath, confDir := newLogrotateLayout(t)
|
||||||
|
sub := filepath.Join(confDir, "nested")
|
||||||
|
require.NoError(t, os.MkdirAll(sub, 0o755))
|
||||||
|
writeLogrotateConfig(t, filepath.Join(sub, "netbird"), `/var/log/netbird/*.log { daily }`)
|
||||||
|
|
||||||
|
got, path := findFirstLogrotateConflictIn(confPath, confDir)
|
||||||
|
require.False(t, got)
|
||||||
|
require.Empty(t, path)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing paths return no conflict", func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
got, path := findFirstLogrotateConflictIn(
|
||||||
|
filepath.Join(dir, "does-not-exist.conf"),
|
||||||
|
filepath.Join(dir, "does-not-exist.d"),
|
||||||
|
)
|
||||||
|
require.False(t, got)
|
||||||
|
require.Empty(t, path)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// newLogrotateLayout creates a temp logrotate.conf path and logrotate.d dir,
|
||||||
|
// returning their paths. The conf file itself is not created.
|
||||||
|
func newLogrotateLayout(t *testing.T) (confPath, confDir string) {
|
||||||
|
t.Helper()
|
||||||
|
root := t.TempDir()
|
||||||
|
confDir = filepath.Join(root, "logrotate.d")
|
||||||
|
require.NoError(t, os.MkdirAll(confDir, 0o755))
|
||||||
|
return filepath.Join(root, "logrotate.conf"), confDir
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeLogrotateConfig(t *testing.T, path, body string) {
|
||||||
|
t.Helper()
|
||||||
|
require.NoError(t, os.WriteFile(path, []byte(body), 0o644))
|
||||||
|
}
|
||||||
10
util/logrotate_nonlinux.go
Normal file
10
util/logrotate_nonlinux.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
package util
|
||||||
|
|
||||||
|
// FindLogrotateConflicts scans the standard logrotate locations for
|
||||||
|
// indications of conflict with netbird. It will always return false for
|
||||||
|
// non-linux devices.
|
||||||
|
func FindFirstLogrotateConflict() (bool, string) {
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user