mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-28 03:29:55 +00:00
Compare commits
7 Commits
feat/dev_v
...
feature/io
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ea0882975 | ||
|
|
944a258459 | ||
|
|
1f9a829f2c | ||
|
|
14af179556 | ||
|
|
1fbb5e6d5d | ||
|
|
6771e35d57 | ||
|
|
e89b1e0596 |
@@ -11,7 +11,7 @@ import (
|
||||
"go.opentelemetry.io/otel"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
|
||||
@@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
||||
iv, _ := validator.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -12,13 +12,7 @@ var (
|
||||
Short: "Print the NetBird's client application version",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
out := version.NetbirdVersion()
|
||||
if version.IsDevelopmentVersion(out) {
|
||||
if commit := version.NetbirdCommit(); commit != "" {
|
||||
out += "-" + commit
|
||||
}
|
||||
}
|
||||
cmd.Println(out)
|
||||
cmd.Println(version.NetbirdVersion())
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
@@ -100,6 +101,26 @@ type Options struct {
|
||||
MTU *uint16
|
||||
// DNSLabels defines additional DNS labels configured in the peer.
|
||||
DNSLabels []string
|
||||
// Performance configures the tunnel's buffer pool cap and batch size.
|
||||
Performance Performance
|
||||
}
|
||||
|
||||
// Performance configures the embedded client's tunnel memory/throughput knobs.
|
||||
//
|
||||
// These settings are process-global: any non-nil field also becomes the
|
||||
// default for Clients constructed by later embed.New calls in the same
|
||||
// process. Nil fields are ignored.
|
||||
type Performance struct {
|
||||
// PreallocatedBuffersPerPool caps the per-tunnel buffer pool. Zero
|
||||
// leaves the pool unbounded. Lower values trade throughput for a
|
||||
// tighter memory ceiling. May also be changed on a running Client via
|
||||
// Client.SetPerformance, provided this field was nonzero at construction.
|
||||
PreallocatedBuffersPerPool *uint32
|
||||
// MaxBatchSize overrides the number of packets the tunnel reads or
|
||||
// writes per syscall, which also bounds eager buffer allocation per
|
||||
// worker. Zero uses the platform default. Applied at construction
|
||||
// only; ignored by Client.SetPerformance.
|
||||
MaxBatchSize *uint32
|
||||
}
|
||||
|
||||
// validateCredentials checks that exactly one credential type is provided
|
||||
@@ -199,6 +220,13 @@ func New(opts Options) (*Client, error) {
|
||||
config.PrivateKey = opts.PrivateKey
|
||||
}
|
||||
|
||||
if opts.Performance.PreallocatedBuffersPerPool != nil {
|
||||
wgdevice.SetPreallocatedBuffersPerPool(*opts.Performance.PreallocatedBuffersPerPool)
|
||||
}
|
||||
if opts.Performance.MaxBatchSize != nil {
|
||||
wgdevice.SetMaxBatchSizeOverride(*opts.Performance.MaxBatchSize)
|
||||
}
|
||||
|
||||
return &Client{
|
||||
deviceName: opts.DeviceName,
|
||||
setupKey: opts.SetupKey,
|
||||
@@ -495,6 +523,25 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
|
||||
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
||||
}
|
||||
|
||||
// SetPerformance retunes a running Client. Only PreallocatedBuffersPerPool
|
||||
// takes effect, and only when it was nonzero at construction;
|
||||
// MaxBatchSize is construction-only and returns an error if set here.
|
||||
//
|
||||
// Returns ErrClientNotStarted / ErrEngineNotStarted if the Client is not
|
||||
// running yet.
|
||||
func (c *Client) SetPerformance(t Performance) error {
|
||||
if t.MaxBatchSize != nil {
|
||||
return errors.New("MaxBatchSize is construction-only and cannot be changed at runtime")
|
||||
}
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return engine.SetPerformance(internal.Performance{
|
||||
PreallocatedBuffersPerPool: t.PreallocatedBuffersPerPool,
|
||||
})
|
||||
}
|
||||
|
||||
// StartCapture begins capturing packets on this client's tunnel device.
|
||||
// Only one capture can be active at a time; starting a new one stops the previous.
|
||||
// Call StopCapture (or CaptureSession.Stop) to end it.
|
||||
|
||||
@@ -1967,6 +1967,29 @@ func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
|
||||
return e.clientMetrics
|
||||
}
|
||||
|
||||
// Performance bundles runtime-adjustable tunnel pool knobs.
|
||||
// See Engine.SetPerformance. Nil fields are ignored.
|
||||
type Performance struct {
|
||||
PreallocatedBuffersPerPool *uint32
|
||||
}
|
||||
|
||||
// SetPerformance applies the given tuning to this engine's live Device.
|
||||
func (e *Engine) SetPerformance(t Performance) error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
if e.wgInterface == nil {
|
||||
return fmt.Errorf("wg interface not initialized")
|
||||
}
|
||||
dev := e.wgInterface.GetWGDevice()
|
||||
if dev == nil {
|
||||
return fmt.Errorf("wg device not initialized")
|
||||
}
|
||||
if t.PreallocatedBuffersPerPool != nil {
|
||||
dev.SetPreallocatedBuffersPerPool(*t.PreallocatedBuffersPerPool)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
|
||||
@@ -27,7 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
@@ -66,8 +66,8 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
@@ -1641,7 +1641,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -4,8 +4,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
|
||||
nbversion "github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -13,7 +11,7 @@ var (
|
||||
)
|
||||
|
||||
func IsSupported(agentVersion string) bool {
|
||||
if nbversion.IsDevelopmentVersion(agentVersion) {
|
||||
if agentVersion == "development" {
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,8 @@ import (
|
||||
|
||||
const (
|
||||
latestVersion = "latest"
|
||||
// this version will be ignored
|
||||
developmentVersion = "development"
|
||||
)
|
||||
|
||||
var errNoUpdateState = errors.New("no update state found")
|
||||
@@ -481,7 +483,7 @@ func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, e
|
||||
}
|
||||
|
||||
func (m *Manager) shouldUpdate(updateVersion *v.Version, forceUpdate bool) bool {
|
||||
if version.IsDevelopmentVersion(m.currentVersion) {
|
||||
if m.currentVersion == developmentVersion {
|
||||
log.Debugf("skipping auto-update, running development version")
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -76,6 +76,9 @@ type Client struct {
|
||||
dnsManager dns.IosDnsManager
|
||||
loginComplete bool
|
||||
connectClient *internal.ConnectClient
|
||||
// config holds the active configuration once Run has loaded it. Consumed by
|
||||
// the in-app SSH client for the NetBird SSH key and the OAuth flow.
|
||||
config *profilemanager.Config
|
||||
// preloadedConfig holds config loaded from JSON (used on tvOS where file writes are blocked)
|
||||
preloadedConfig *profilemanager.Config
|
||||
}
|
||||
@@ -160,6 +163,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
c.onHostDnsFn = func([]string) {}
|
||||
cfg.WgIface = interfaceName
|
||||
c.config = cfg
|
||||
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
|
||||
@@ -527,6 +531,13 @@ func (c *Client) DeselectRoute(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// sshState returns the active config and the running connect client for the
|
||||
// in-app SSH client. Both are nil until Run has loaded the config and started
|
||||
// the tunnel.
|
||||
func (c *Client) sshState() (*profilemanager.Config, *internal.ConnectClient) {
|
||||
return c.config, c.connectClient
|
||||
}
|
||||
|
||||
func formatDuration(d time.Duration) string {
|
||||
ds := d.String()
|
||||
dotIndex := strings.Index(ds, ".")
|
||||
|
||||
431
client/ios/NetBirdSDK/ssh_client.go
Normal file
431
client/ios/NetBirdSDK/ssh_client.go
Normal file
@@ -0,0 +1,431 @@
|
||||
//go:build ios
|
||||
|
||||
package NetBirdSDK
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
)
|
||||
|
||||
const (
|
||||
sshDialTimeout = 30 * time.Second
|
||||
sshDetectionTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// SSHTerminalListener receives SSH session events. It is implemented in Swift.
|
||||
//
|
||||
// All callbacks are invoked from goroutines and may run concurrently with each
|
||||
// other; the implementation must be safe to call from any thread.
|
||||
type SSHTerminalListener interface {
|
||||
OnConnected()
|
||||
OnData(data []byte)
|
||||
OnClose(reason string)
|
||||
OnError(message string)
|
||||
}
|
||||
|
||||
// SSHClient is a NetBird-aware SSH client exposed to Swift via gomobile.
|
||||
//
|
||||
// It dials through the running NetBird tunnel and runs a standard SSH session
|
||||
// on top with PTY enabled. Host-key verification uses the NetBird-provided
|
||||
// peer SSH host keys, identical to the desktop client.
|
||||
type SSHClient struct {
|
||||
nb *Client
|
||||
mu sync.Mutex
|
||||
listener SSHTerminalListener
|
||||
urlOpener URLOpener
|
||||
|
||||
sshClient *gossh.Client
|
||||
session *gossh.Session
|
||||
stdin io.WriteCloser
|
||||
closed bool
|
||||
}
|
||||
|
||||
// NewSSHClient creates a new SSH client bound to the running NetBird Client.
|
||||
func NewSSHClient(c *Client) *SSHClient {
|
||||
return &SSHClient{nb: c}
|
||||
}
|
||||
|
||||
// SetListener registers the Swift listener. Must be called before Connect to
|
||||
// receive any events.
|
||||
func (s *SSHClient) SetListener(l SSHTerminalListener) {
|
||||
s.mu.Lock()
|
||||
s.listener = l
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetURLOpener registers the Swift URL opener used to display the device-code
|
||||
// authorization page in an in-app browser when the target peer requires JWT
|
||||
// authentication. Must be set before Connect to be effective.
|
||||
func (s *SSHClient) SetURLOpener(opener URLOpener) {
|
||||
s.mu.Lock()
|
||||
s.urlOpener = opener
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// Connect dials the SSH server through the NetBird tunnel and performs the
|
||||
// SSH handshake. It auto-detects the server type via SSH banner inspection
|
||||
// and selects the appropriate authentication path:
|
||||
//
|
||||
// - NetBird-SSH server requiring JWT: launches the OAuth 2.0 device-code
|
||||
// flow, opens the verification URL through the registered URLOpener, and
|
||||
// uses the resulting token as the SSH password. Host-key verification
|
||||
// uses the NetBird peer registry.
|
||||
// - NetBird-SSH server without JWT: authenticates with the NetBird SSH
|
||||
// private key. Host-key verification uses the NetBird peer registry.
|
||||
// - Regular SSH server (e.g. OpenSSH): authenticates with the NetBird key
|
||||
// first (so a user-installed NetBird public key works), then falls back
|
||||
// to the supplied password if non-empty. Host-key verification is
|
||||
// disabled (TOFU pending).
|
||||
//
|
||||
// The password parameter is only consulted for regular SSH servers.
|
||||
func (s *SSHClient) Connect(host string, port int, user, password string) error {
|
||||
cfg, cc := s.nb.sshState()
|
||||
if cc == nil {
|
||||
return errors.New("netbird client not running")
|
||||
}
|
||||
if cfg == nil {
|
||||
return errors.New("netbird config not loaded")
|
||||
}
|
||||
engine := cc.Engine()
|
||||
if engine == nil {
|
||||
return errors.New("netbird engine not available")
|
||||
}
|
||||
|
||||
serverType := detectServerType(host, port)
|
||||
log.Infof("SSH server type for %s:%d: %s", host, port, serverType)
|
||||
|
||||
authMethods, hostKeyCallback, err := s.buildAuth(cfg, engine, serverType, password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clientConfig := &gossh.ClientConfig{
|
||||
User: user,
|
||||
Auth: authMethods,
|
||||
HostKeyCallback: hostKeyCallback,
|
||||
Timeout: sshDialTimeout,
|
||||
}
|
||||
return s.dialAndHandshake(host, port, clientConfig)
|
||||
}
|
||||
|
||||
// StartSession requests a PTY and starts an interactive shell. Output from
|
||||
// the session is forwarded to the listener via OnData.
|
||||
func (s *SSHClient) StartSession(cols, rows int) error {
|
||||
log.Debugf("SSH: starting session %dx%d", cols, rows)
|
||||
s.mu.Lock()
|
||||
sshClient := s.sshClient
|
||||
s.mu.Unlock()
|
||||
|
||||
if sshClient == nil {
|
||||
return errors.New("ssh client not connected")
|
||||
}
|
||||
|
||||
session, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("new session: %w", err)
|
||||
}
|
||||
|
||||
modes := gossh.TerminalModes{
|
||||
gossh.ECHO: 1,
|
||||
gossh.TTY_OP_ISPEED: 14400,
|
||||
gossh.TTY_OP_OSPEED: 14400,
|
||||
gossh.VINTR: 3,
|
||||
gossh.VQUIT: 28,
|
||||
gossh.VERASE: 127,
|
||||
}
|
||||
if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil {
|
||||
closeQuiet(session, "session after pty error")
|
||||
return fmt.Errorf("request pty: %w", err)
|
||||
}
|
||||
|
||||
stdin, err := session.StdinPipe()
|
||||
if err != nil {
|
||||
closeQuiet(session, "session after stdin error")
|
||||
return fmt.Errorf("stdin pipe: %w", err)
|
||||
}
|
||||
stdout, err := session.StdoutPipe()
|
||||
if err != nil {
|
||||
closeQuiet(session, "session after stdout error")
|
||||
return fmt.Errorf("stdout pipe: %w", err)
|
||||
}
|
||||
stderr, err := session.StderrPipe()
|
||||
if err != nil {
|
||||
closeQuiet(session, "session after stderr error")
|
||||
return fmt.Errorf("stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
closeQuiet(session, "session after shell error")
|
||||
return fmt.Errorf("start shell: %w", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.session = session
|
||||
s.stdin = stdin
|
||||
s.mu.Unlock()
|
||||
|
||||
go s.readLoop(stdout, "stdout")
|
||||
go s.readLoop(stderr, "stderr")
|
||||
log.Debug("SSH: session started, shell running")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write sends data to the SSH session stdin.
|
||||
func (s *SSHClient) Write(data []byte) error {
|
||||
s.mu.Lock()
|
||||
stdin := s.stdin
|
||||
s.mu.Unlock()
|
||||
if stdin == nil {
|
||||
return errors.New("ssh session not started")
|
||||
}
|
||||
if _, err := stdin.Write(data); err != nil {
|
||||
return fmt.Errorf("write stdin: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resize updates the PTY window size.
|
||||
func (s *SSHClient) Resize(cols, rows int) error {
|
||||
s.mu.Lock()
|
||||
session := s.session
|
||||
s.mu.Unlock()
|
||||
if session == nil {
|
||||
return errors.New("ssh session not started")
|
||||
}
|
||||
return session.WindowChange(rows, cols)
|
||||
}
|
||||
|
||||
// Close terminates the SSH session and underlying connection. Safe to call
|
||||
// multiple times.
|
||||
func (s *SSHClient) Close() error {
|
||||
s.mu.Lock()
|
||||
sshClient := s.sshClient
|
||||
session := s.session
|
||||
stdin := s.stdin
|
||||
s.sshClient = nil
|
||||
s.session = nil
|
||||
s.stdin = nil
|
||||
s.mu.Unlock()
|
||||
|
||||
if stdin != nil {
|
||||
if err := stdin.Close(); err != nil {
|
||||
log.Debugf("ssh: stdin close: %v", err)
|
||||
}
|
||||
}
|
||||
if session != nil {
|
||||
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Debugf("ssh: session close: %v", err)
|
||||
}
|
||||
}
|
||||
var firstErr error
|
||||
if sshClient != nil {
|
||||
if err := sshClient.Close(); err != nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
s.notifyClose("closed by client")
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (s *SSHClient) buildAuth(cfg *profilemanager.Config, engine *internal.Engine,
|
||||
serverType detection.ServerType, password string) ([]gossh.AuthMethod, gossh.HostKeyCallback, error) {
|
||||
|
||||
switch serverType {
|
||||
case detection.ServerTypeNetBirdJWT:
|
||||
token, err := s.requestJWTToken(cfg)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("jwt: %w", err)
|
||||
}
|
||||
auths := []gossh.AuthMethod{gossh.Password(token)}
|
||||
return auths, nbssh.CreateHostKeyCallback(&engineHostKeyVerifier{engine: engine}), nil
|
||||
|
||||
case detection.ServerTypeNetBirdNoJWT:
|
||||
if cfg.SSHKey == "" {
|
||||
return nil, nil, errors.New("no NetBird SSH key available")
|
||||
}
|
||||
signer, err := gossh.ParsePrivateKey([]byte(cfg.SSHKey))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("parse netbird ssh key: %w", err)
|
||||
}
|
||||
auths := []gossh.AuthMethod{gossh.PublicKeys(signer)}
|
||||
return auths, nbssh.CreateHostKeyCallback(&engineHostKeyVerifier{engine: engine}), nil
|
||||
|
||||
default: // regular SSH
|
||||
var auths []gossh.AuthMethod
|
||||
if cfg.SSHKey != "" {
|
||||
if signer, err := gossh.ParsePrivateKey([]byte(cfg.SSHKey)); err == nil {
|
||||
auths = append(auths, gossh.PublicKeys(signer))
|
||||
} else {
|
||||
log.Debugf("ssh: parse netbird key for regular auth: %v", err)
|
||||
}
|
||||
}
|
||||
if password != "" {
|
||||
pw := password
|
||||
auths = append(auths, gossh.Password(pw))
|
||||
auths = append(auths, gossh.KeyboardInteractive(func(_, _ string, questions []string, _ []bool) ([]string, error) {
|
||||
answers := make([]string, len(questions))
|
||||
for i := range questions {
|
||||
answers[i] = pw
|
||||
}
|
||||
return answers, nil
|
||||
}))
|
||||
}
|
||||
if len(auths) == 0 {
|
||||
return nil, nil, errors.New("no auth method available: provide a password or configure NetBird SSH key")
|
||||
}
|
||||
return auths, gossh.InsecureIgnoreHostKey(), nil // nolint:gosec // TOFU not yet implemented
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SSHClient) requestJWTToken(cfg *profilemanager.Config) (string, error) {
|
||||
s.mu.Lock()
|
||||
urlOpener := s.urlOpener
|
||||
s.mu.Unlock()
|
||||
if urlOpener == nil {
|
||||
return "", errors.New("URL opener not configured for JWT auth")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
flow, err := auth.NewOAuthFlow(ctx, cfg, false, true, profilemanager.GetLoginHint())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create oauth flow: %w", err)
|
||||
}
|
||||
|
||||
flowInfo, err := flow.RequestAuthInfo(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request auth info: %w", err)
|
||||
}
|
||||
|
||||
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||
|
||||
tokenInfo, err := flow.WaitToken(ctx, flowInfo)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("wait for token: %w", err)
|
||||
}
|
||||
|
||||
token := tokenInfo.GetTokenToUse()
|
||||
if token == "" {
|
||||
return "", errors.New("empty token returned by IdP")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *SSHClient) dialAndHandshake(host string, port int, clientConfig *gossh.ClientConfig) error {
|
||||
addr := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
log.Infof("SSH: connecting to %s as %s", addr, clientConfig.User)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
var dialer net.Dialer
|
||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial %s: %w", addr, err)
|
||||
}
|
||||
|
||||
sshConn, chans, reqs, err := gossh.NewClientConn(conn, addr, clientConfig)
|
||||
if err != nil {
|
||||
if cerr := conn.Close(); cerr != nil {
|
||||
log.Debugf("ssh: close after handshake error: %v", cerr)
|
||||
}
|
||||
return fmt.Errorf("ssh handshake: %w", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.sshClient = gossh.NewClient(sshConn, chans, reqs)
|
||||
listener := s.listener
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Infof("SSH: connected to %s", addr)
|
||||
if listener != nil {
|
||||
listener.OnConnected()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SSHClient) readLoop(r io.Reader, name string) {
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := r.Read(buf)
|
||||
if n > 0 {
|
||||
s.mu.Lock()
|
||||
listener := s.listener
|
||||
s.mu.Unlock()
|
||||
if listener != nil {
|
||||
chunk := make([]byte, n)
|
||||
copy(chunk, buf[:n])
|
||||
listener.OnData(chunk)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
log.Debugf("ssh %s read: %v", name, err)
|
||||
}
|
||||
s.notifyClose(err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SSHClient) notifyClose(reason string) {
|
||||
s.mu.Lock()
|
||||
if s.closed {
|
||||
s.mu.Unlock()
|
||||
return
|
||||
}
|
||||
s.closed = true
|
||||
listener := s.listener
|
||||
s.mu.Unlock()
|
||||
if listener != nil {
|
||||
listener.OnClose(reason)
|
||||
}
|
||||
}
|
||||
|
||||
// engineHostKeyVerifier adapts *internal.Engine to nbssh.HostKeyVerifier.
|
||||
type engineHostKeyVerifier struct {
|
||||
engine *internal.Engine
|
||||
}
|
||||
|
||||
func (v *engineHostKeyVerifier) VerifySSHHostKey(peerAddress string, presented []byte) error {
|
||||
storedKey, found := v.engine.GetPeerSSHKey(peerAddress)
|
||||
if !found {
|
||||
return nbssh.ErrPeerNotFound
|
||||
}
|
||||
return nbssh.VerifyHostKey(storedKey, presented, peerAddress)
|
||||
}
|
||||
|
||||
func detectServerType(host string, port int) detection.ServerType {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sshDetectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
dialer := &net.Dialer{}
|
||||
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
|
||||
if err != nil {
|
||||
log.Debugf("ssh: server detection for %s:%d failed: %v (assuming regular SSH)", host, port, err)
|
||||
return detection.ServerTypeRegular
|
||||
}
|
||||
return serverType
|
||||
}
|
||||
|
||||
func closeQuiet(c io.Closer, label string) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if err := c.Close(); err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Debugf("ssh: close %s: %v", label, err)
|
||||
}
|
||||
}
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
@@ -315,7 +315,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"sync"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
@@ -13,7 +14,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
certValidationTimeout = 60 * time.Second
|
||||
certValidationTimeout = 5 * time.Minute
|
||||
)
|
||||
|
||||
func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) {
|
||||
@@ -46,17 +47,31 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert
|
||||
|
||||
promise := conn.wsHandlers.Call("onCertificateRequest", certInfo)
|
||||
|
||||
resultChan := make(chan bool)
|
||||
errorChan := make(chan error)
|
||||
resultChan := make(chan bool, 1)
|
||||
errorChan := make(chan error, 1)
|
||||
|
||||
promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
result := args[0].Bool()
|
||||
resultChan <- result
|
||||
// Release from inside the callbacks so a post-timeout promise resolution
|
||||
// does not invoke an already-released func.
|
||||
var thenFn, catchFn js.Func
|
||||
var releaseOnce sync.Once
|
||||
release := func() {
|
||||
releaseOnce.Do(func() {
|
||||
thenFn.Release()
|
||||
catchFn.Release()
|
||||
})
|
||||
}
|
||||
thenFn = js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
defer release()
|
||||
resultChan <- args[0].Bool()
|
||||
return nil
|
||||
})).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
})
|
||||
catchFn = js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
defer release()
|
||||
errorChan <- fmt.Errorf("certificate validation failed")
|
||||
return nil
|
||||
}))
|
||||
})
|
||||
|
||||
promise.Call("then", thenFn).Call("catch", catchFn)
|
||||
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
@@ -57,6 +58,8 @@ type RDCleanPathProxy struct {
|
||||
}
|
||||
activeConnections map[string]*proxyConnection
|
||||
destinations map[string]string
|
||||
pendingHandlers map[string]js.Func
|
||||
nextID atomic.Uint64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
@@ -66,8 +69,15 @@ type proxyConnection struct {
|
||||
rdpConn net.Conn
|
||||
tlsConn *tls.Conn
|
||||
wsHandlers js.Value
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
// Go-side callbacks exposed to JS. js.FuncOf pins the Go closure in a
|
||||
// global handle map and MUST be released, otherwise every connection
|
||||
// leaks the Go memory the closure captures.
|
||||
wsHandlerFn js.Func
|
||||
onMessageFn js.Func
|
||||
onCloseFn js.Func
|
||||
cleanupOnce sync.Once
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewRDCleanPathProxy creates a new RDCleanPath proxy
|
||||
@@ -80,7 +90,11 @@ func NewRDCleanPathProxy(client interface {
|
||||
}
|
||||
}
|
||||
|
||||
// CreateProxy creates a new proxy endpoint for the given destination
|
||||
// CreateProxy creates a new proxy endpoint for the given destination.
|
||||
// The registered handler fn and its destinations/pendingHandlers entries are
|
||||
// only released once a connection is established and cleanupConnection runs.
|
||||
// If a caller invokes CreateProxy but never connects to the returned URL,
|
||||
// those entries stay pinned for the lifetime of the page.
|
||||
func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
||||
destination := net.JoinHostPort(hostname, port)
|
||||
|
||||
@@ -88,7 +102,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
||||
resolve := args[0]
|
||||
|
||||
go func() {
|
||||
proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections))
|
||||
proxyID := fmt.Sprintf("proxy_%d", p.nextID.Add(1))
|
||||
|
||||
p.mu.Lock()
|
||||
if p.destinations == nil {
|
||||
@@ -100,7 +114,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
||||
proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID)
|
||||
|
||||
// Register the WebSocket handler for this specific proxy
|
||||
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
handlerFn := js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf("error: requires WebSocket argument")
|
||||
}
|
||||
@@ -108,7 +122,14 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
||||
ws := args[0]
|
||||
p.HandleWebSocketConnection(ws, proxyID)
|
||||
return nil
|
||||
}))
|
||||
})
|
||||
p.mu.Lock()
|
||||
if p.pendingHandlers == nil {
|
||||
p.pendingHandlers = make(map[string]js.Func)
|
||||
}
|
||||
p.pendingHandlers[proxyID] = handlerFn
|
||||
p.mu.Unlock()
|
||||
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), handlerFn)
|
||||
|
||||
log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination)
|
||||
resolve.Invoke(proxyURL)
|
||||
@@ -142,6 +163,10 @@ func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string
|
||||
|
||||
p.mu.Lock()
|
||||
p.activeConnections[proxyID] = conn
|
||||
if fn, ok := p.pendingHandlers[proxyID]; ok {
|
||||
conn.wsHandlerFn = fn
|
||||
delete(p.pendingHandlers, proxyID)
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
p.setupWebSocketHandlers(ws, conn)
|
||||
@@ -150,7 +175,7 @@ func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) {
|
||||
ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
conn.onMessageFn = js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return nil
|
||||
}
|
||||
@@ -158,13 +183,15 @@ func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnec
|
||||
data := args[0]
|
||||
go p.handleWebSocketMessage(conn, data)
|
||||
return nil
|
||||
}))
|
||||
})
|
||||
ws.Set("onGoMessage", conn.onMessageFn)
|
||||
|
||||
ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
conn.onCloseFn = js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
log.Debug("WebSocket closed by JavaScript")
|
||||
conn.cancel()
|
||||
return nil
|
||||
}))
|
||||
})
|
||||
ws.Set("onGoClose", conn.onCloseFn)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) {
|
||||
@@ -261,25 +288,49 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) {
|
||||
log.Debugf("Cleaning up connection %s", conn.id)
|
||||
conn.cancel()
|
||||
if conn.tlsConn != nil {
|
||||
log.Debug("Closing TLS connection")
|
||||
if err := conn.tlsConn.Close(); err != nil {
|
||||
log.Debugf("Error closing TLS connection: %v", err)
|
||||
conn.cleanupOnce.Do(func() {
|
||||
log.Debugf("Cleaning up connection %s", conn.id)
|
||||
conn.cancel()
|
||||
if conn.tlsConn != nil {
|
||||
log.Debug("Closing TLS connection")
|
||||
if err := conn.tlsConn.Close(); err != nil {
|
||||
log.Debugf("Error closing TLS connection: %v", err)
|
||||
}
|
||||
conn.tlsConn = nil
|
||||
}
|
||||
conn.tlsConn = nil
|
||||
}
|
||||
if conn.rdpConn != nil {
|
||||
log.Debug("Closing TCP connection")
|
||||
if err := conn.rdpConn.Close(); err != nil {
|
||||
log.Debugf("Error closing TCP connection: %v", err)
|
||||
if conn.rdpConn != nil {
|
||||
log.Debug("Closing TCP connection")
|
||||
if err := conn.rdpConn.Close(); err != nil {
|
||||
log.Debugf("Error closing TCP connection: %v", err)
|
||||
}
|
||||
conn.rdpConn = nil
|
||||
}
|
||||
conn.rdpConn = nil
|
||||
}
|
||||
p.mu.Lock()
|
||||
delete(p.activeConnections, conn.id)
|
||||
p.mu.Unlock()
|
||||
js.Global().Delete(fmt.Sprintf("handleRDCleanPathWebSocket_%s", conn.id))
|
||||
|
||||
// Detach before releasing so late JS calls surface as TypeError instead
|
||||
// of silent "call to released function".
|
||||
if conn.wsHandlers.Truthy() {
|
||||
conn.wsHandlers.Set("onGoMessage", js.Undefined())
|
||||
conn.wsHandlers.Set("onGoClose", js.Undefined())
|
||||
}
|
||||
|
||||
// wsHandlerFn may be zero-value if the pending handler lookup missed.
|
||||
if conn.wsHandlerFn.Truthy() {
|
||||
conn.wsHandlerFn.Release()
|
||||
}
|
||||
if conn.onMessageFn.Truthy() {
|
||||
conn.onMessageFn.Release()
|
||||
}
|
||||
if conn.onCloseFn.Truthy() {
|
||||
conn.onCloseFn.Release()
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
delete(p.activeConnections, conn.id)
|
||||
delete(p.destinations, conn.id)
|
||||
delete(p.pendingHandlers, conn.id)
|
||||
p.mu.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
func CreateJSInterface(client *Client) js.Value {
|
||||
jsInterface := js.Global().Get("Object").Call("create", js.Null())
|
||||
|
||||
jsInterface.Set("write", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
writeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
@@ -32,9 +32,10 @@ func CreateJSInterface(client *Client) js.Value {
|
||||
|
||||
_, err := client.Write(bytes)
|
||||
return js.ValueOf(err == nil)
|
||||
}))
|
||||
})
|
||||
jsInterface.Set("write", writeFunc)
|
||||
|
||||
jsInterface.Set("resize", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
resizeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 2 {
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
@@ -42,14 +43,26 @@ func CreateJSInterface(client *Client) js.Value {
|
||||
rows := args[1].Int()
|
||||
err := client.Resize(cols, rows)
|
||||
return js.ValueOf(err == nil)
|
||||
}))
|
||||
})
|
||||
jsInterface.Set("resize", resizeFunc)
|
||||
|
||||
jsInterface.Set("close", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
closeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
client.Close()
|
||||
return js.Undefined()
|
||||
}))
|
||||
})
|
||||
jsInterface.Set("close", closeFunc)
|
||||
|
||||
go readLoop(client, jsInterface)
|
||||
go func() {
|
||||
readLoop(client, jsInterface)
|
||||
// Detach before releasing so late JS calls surface as TypeError instead
|
||||
// of silent "call to released function".
|
||||
jsInterface.Set("write", js.Undefined())
|
||||
jsInterface.Set("resize", js.Undefined())
|
||||
jsInterface.Set("close", js.Undefined())
|
||||
writeFunc.Release()
|
||||
resizeFunc.Release()
|
||||
closeFunc.Release()
|
||||
}()
|
||||
|
||||
return jsInterface
|
||||
}
|
||||
|
||||
@@ -332,7 +332,7 @@ func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) {
|
||||
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
|
||||
}
|
||||
|
||||
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
|
||||
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), s.IDPHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
|
||||
if servers.relaySrv != nil {
|
||||
log.Infof("Relay WebSocket handler added (path: /relay)")
|
||||
}
|
||||
@@ -521,7 +521,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
|
||||
}
|
||||
|
||||
// createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic
|
||||
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
|
||||
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
|
||||
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
|
||||
|
||||
var relayAcceptFn func(conn listener.Conn)
|
||||
@@ -556,6 +556,10 @@ func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, re
|
||||
http.Error(w, "Relay service not enabled", http.StatusNotFound)
|
||||
}
|
||||
|
||||
// Embedded IdP (Dex)
|
||||
case idpHandler != nil && strings.HasPrefix(r.URL.Path, "/oauth2"):
|
||||
idpHandler.ServeHTTP(w, r)
|
||||
|
||||
// Management HTTP API (default)
|
||||
default:
|
||||
httpHandler.ServeHTTP(w, r)
|
||||
|
||||
2
go.mod
2
go.mod
@@ -335,7 +335,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
||||
|
||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f
|
||||
|
||||
replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||
|
||||
|
||||
4
go.sum
4
go.sum
@@ -499,8 +499,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f h1:ff2D57RBjWtyQ2wVwJOxOgXAXOe/J2lJWtSX0Bz/BRk=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
|
||||
|
||||
@@ -32,7 +32,6 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
type Controller struct {
|
||||
@@ -113,7 +112,7 @@ func (c *Controller) CountStreams() int {
|
||||
return c.peersUpdateManager.CountStreams()
|
||||
}
|
||||
|
||||
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
@@ -176,6 +175,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
continue
|
||||
}
|
||||
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountNmapTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
semaphore <- struct{}{}
|
||||
go func(p *nbpeer.Peer) {
|
||||
@@ -243,14 +246,14 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
|
||||
|
||||
go func() {
|
||||
defer b.mu.Unlock()
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||
if !b.update.Load() {
|
||||
return
|
||||
}
|
||||
b.update.Store(false)
|
||||
if b.next == nil {
|
||||
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -266,7 +269,7 @@ func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, r
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
return c.sendUpdateAccountPeers(ctx, accountID)
|
||||
return c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||
}
|
||||
|
||||
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
|
||||
@@ -360,14 +363,14 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
|
||||
|
||||
go func() {
|
||||
defer b.mu.Unlock()
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||
if !b.update.Load() {
|
||||
return
|
||||
}
|
||||
b.update.Store(false)
|
||||
if b.next == nil {
|
||||
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -511,7 +514,7 @@ func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
|
||||
for _, peer := range peers {
|
||||
|
||||
// Development version is always supported
|
||||
if version.IsDevelopmentVersion(peer.Meta.WtVersion) {
|
||||
if peer.Meta.WtVersion == "development" {
|
||||
continue
|
||||
}
|
||||
peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)
|
||||
|
||||
@@ -51,7 +51,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
found = true
|
||||
select {
|
||||
case channel <- update:
|
||||
log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID)
|
||||
log.WithContext(ctx).Tracef("update was sent to channel for peer %s", peerID)
|
||||
default:
|
||||
dropped = true
|
||||
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))
|
||||
|
||||
@@ -10,8 +10,10 @@ import (
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
|
||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||
"github.com/rs/cors"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
@@ -19,7 +21,6 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
cachestore "github.com/eko/gocache/lib/v4/store"
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
@@ -27,16 +28,20 @@ import (
|
||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
activitystore "github.com/netbirdio/netbird/management/server/activity/store"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
const apiPrefix = "/api"
|
||||
|
||||
var (
|
||||
kaep = keepalive.EnforcementPolicy{
|
||||
MinTime: 15 * time.Second,
|
||||
@@ -94,12 +99,17 @@ func (s *BaseServer) Store() store.Store {
|
||||
|
||||
func (s *BaseServer) EventStore() activity.Store {
|
||||
return Create(s, func() activity.Store {
|
||||
integrationMetrics, err := integrations.InitIntegrationMetrics(context.Background(), s.Metrics())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize integration metrics: %v", err)
|
||||
var err error
|
||||
key := s.Config.DataStoreEncryptionKey
|
||||
if key == "" {
|
||||
log.Debugf("generate new activity store encryption key")
|
||||
key, err = crypt.GenerateKey()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to generate event store encryption key: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
eventStore, _, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics)
|
||||
eventStore, err := activitystore.NewSqlStore(context.Background(), s.Config.Datadir, key)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize event store: %v", err)
|
||||
}
|
||||
@@ -110,7 +120,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
||||
|
||||
func (s *BaseServer) APIHandler() http.Handler {
|
||||
return Create(s, func() http.Handler {
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter())
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -118,6 +128,22 @@ func (s *BaseServer) APIHandler() http.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
// IDPHandler returns the HTTP handler for the embedded IdP (Dex), or nil if
|
||||
// the deployment isn't using the embedded variant.
|
||||
func (s *BaseServer) IDPHandler() http.Handler {
|
||||
embeddedIdP, ok := s.IdpManager().(*idp.EmbeddedIdPManager)
|
||||
if !ok || embeddedIdP == nil {
|
||||
return nil
|
||||
}
|
||||
return cors.AllowAll().Handler(embeddedIdP.Handler())
|
||||
}
|
||||
|
||||
func (s *BaseServer) Router() *mux.Router {
|
||||
return Create(s, func() *mux.Router {
|
||||
return mux.NewRouter().PathPrefix(apiPrefix).Subrouter()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
|
||||
return Create(s, func() *middleware.APIRateLimiter {
|
||||
cfg, enabled := middleware.RateLimiterConfigFromEnv()
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
@@ -38,7 +39,7 @@ func (s *BaseServer) JobManager() *job.Manager {
|
||||
|
||||
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
|
||||
return Create(s, func() integrated_validator.IntegratedValidator {
|
||||
integratedPeerValidator, err := integrations.NewIntegratedValidator(
|
||||
integratedPeerValidator, err := validator.NewIntegratedValidator(
|
||||
context.Background(),
|
||||
s.PeersManager(),
|
||||
s.SettingsManager(),
|
||||
|
||||
@@ -57,13 +57,7 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
|
||||
|
||||
func (s *BaseServer) PermissionsManager() permissions.Manager {
|
||||
return Create(s, func() permissions.Manager {
|
||||
manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter())
|
||||
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
manager.SetAccountManager(s.AccountManager())
|
||||
})
|
||||
|
||||
return manager
|
||||
return permissions.NewManager(s.Store())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -153,7 +147,6 @@ func (s *BaseServer) IdpManager() idp.Manager {
|
||||
return idpManager
|
||||
}
|
||||
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@@ -235,3 +228,7 @@ func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
||||
return &m
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) IsValidChildAccount(_ context.Context, _, _, _ string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -188,7 +188,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
|
||||
}
|
||||
|
||||
rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter())
|
||||
rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.IDPHandler(), s.Metrics().GetMeter())
|
||||
switch {
|
||||
case s.certManager != nil:
|
||||
// a call to certManager.Listener() always creates a new listener so we do it once
|
||||
@@ -299,7 +299,7 @@ func (s *BaseServer) SetHandlerFunc(handler http.Handler) {
|
||||
log.Tracef("custom handler set successfully")
|
||||
}
|
||||
|
||||
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
|
||||
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, meter metric.Meter) http.Handler {
|
||||
// Check if a custom handler was set (for multiplexing additional services)
|
||||
if customHandler, ok := s.GetContainer("customHandler"); ok {
|
||||
if handler, ok := customHandler.(http.Handler); ok {
|
||||
@@ -318,6 +318,8 @@ func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, ht
|
||||
gRPCHandler.ServeHTTP(writer, request)
|
||||
case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
|
||||
wsProxy.Handler().ServeHTTP(writer, request)
|
||||
case idpHandler != nil && strings.HasPrefix(request.URL.Path, "/oauth2"):
|
||||
idpHandler.ServeHTTP(writer, request)
|
||||
default:
|
||||
httpHandler.ServeHTTP(writer, request)
|
||||
}
|
||||
|
||||
@@ -437,7 +437,7 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
||||
return nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||
log.WithContext(ctx).Tracef("received an update for peer %s", peerKey.String())
|
||||
if debouncer.ProcessUpdate(update) {
|
||||
// Send immediately (first update or after quiet period)
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
|
||||
@@ -492,7 +492,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
|
||||
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||
return status.Errorf(codes.Internal, "failed sending update message")
|
||||
}
|
||||
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
||||
log.WithContext(ctx).Tracef("sent an update to peer %s", peerKey.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -15,15 +15,13 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||
@@ -32,12 +30,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
|
||||
|
||||
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
nbgroups "github.com/netbirdio/netbird/management/server/groups"
|
||||
@@ -56,17 +52,14 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const apiPrefix = "/api"
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
|
||||
func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter, isValidChildAccount middleware.IsValidChildAccountFunc) (http.Handler, error) {
|
||||
|
||||
// Register bypass paths for unauthenticated endpoints
|
||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||
@@ -100,25 +93,16 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
accountManager.GetUserFromUserAuth,
|
||||
rateLimiter,
|
||||
appMetrics.GetMeter(),
|
||||
isValidChildAccount,
|
||||
)
|
||||
|
||||
corsMiddleware := cors.AllowAll()
|
||||
|
||||
rootRouter := mux.NewRouter()
|
||||
metricsMiddleware := appMetrics.HTTPMiddleware()
|
||||
|
||||
prefix := apiPrefix
|
||||
router := rootRouter.PathPrefix(prefix).Subrouter()
|
||||
|
||||
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler)
|
||||
|
||||
if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, integratedValidator, appMetrics.GetMeter(), permissionsManager, peersManager, proxyController, settingsManager); err != nil {
|
||||
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
||||
}
|
||||
|
||||
// Check if embedded IdP is enabled for instance manager
|
||||
embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager)
|
||||
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP)
|
||||
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), idpManager)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create instance manager: %w", err)
|
||||
}
|
||||
@@ -154,10 +138,5 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
oauthHandler.RegisterEndpoints(router)
|
||||
}
|
||||
|
||||
// Mount embedded IdP handler at /oauth2 path if configured
|
||||
if embeddedIdpEnabled {
|
||||
rootRouter.PathPrefix("/oauth2").Handler(corsMiddleware.Handler(embeddedIdP.Handler()))
|
||||
}
|
||||
|
||||
return rootRouter, nil
|
||||
return router, nil
|
||||
}
|
||||
|
||||
@@ -11,8 +11,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
@@ -27,6 +25,8 @@ type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) err
|
||||
|
||||
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
|
||||
type IsValidChildAccountFunc func(ctx context.Context, userID, accountID, childAccountID string) bool
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
authManager serverauth.Manager
|
||||
@@ -35,6 +35,7 @@ type AuthMiddleware struct {
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc
|
||||
rateLimiter *APIRateLimiter
|
||||
patUsageTracker *PATUsageTracker
|
||||
isValidChildAccount IsValidChildAccountFunc
|
||||
}
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
@@ -45,6 +46,7 @@ func NewAuthMiddleware(
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||
rateLimiter *APIRateLimiter,
|
||||
meter metric.Meter,
|
||||
isValidChildAccount IsValidChildAccountFunc,
|
||||
) *AuthMiddleware {
|
||||
var patUsageTracker *PATUsageTracker
|
||||
if meter != nil {
|
||||
@@ -62,6 +64,7 @@ func NewAuthMiddleware(
|
||||
getUserFromUserAuth: getUserFromUserAuth,
|
||||
rateLimiter: rateLimiter,
|
||||
patUsageTracker: patUsageTracker,
|
||||
isValidChildAccount: isValidChildAccount,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,7 +127,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
||||
}
|
||||
|
||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||
if integrations.IsValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||
if m.isValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||
userAuth.AccountId = impersonate[0]
|
||||
userAuth.IsChild = true
|
||||
}
|
||||
@@ -203,7 +206,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
}
|
||||
|
||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||
if integrations.IsValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||
if m.isValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||
userAuth.AccountId = impersonate[0]
|
||||
userAuth.IsChild = true
|
||||
}
|
||||
|
||||
@@ -211,6 +211,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
},
|
||||
disabledLimiter,
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handlerToTest := authMiddleware.Handler(nextHandler)
|
||||
@@ -270,6 +271,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -322,6 +324,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -365,6 +368,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -409,6 +413,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -473,6 +478,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -532,6 +538,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -587,6 +594,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -687,6 +695,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
},
|
||||
disabledLimiter,
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
for _, tc := range tt {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
@@ -135,7 +136,8 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
||||
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -264,7 +266,8 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
||||
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
package validator
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
cachestore "github.com/eko/gocache/lib/v4/store"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type IntegratedValidatorImpl struct{}
|
||||
|
||||
func NewIntegratedValidator(_ context.Context, _ peers.Manager, _ settings.Manager, _ activity.Store, _ cachestore.StoreInterface) (*IntegratedValidatorImpl, error) {
|
||||
return &IntegratedValidatorImpl{}, nil
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) ValidateExtraSettings(context.Context, *types.ExtraSettings, *types.ExtraSettings, string, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) ValidatePeer(_ context.Context, update *nbpeer.Peer, _ *nbpeer.Peer, _ string, _ string, _ string, _ []string, _ *types.ExtraSettings) (*nbpeer.Peer, bool, error) {
|
||||
return update, false, nil
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) PreparePeer(_ context.Context, _ string, peer *nbpeer.Peer, _ []string, _ *types.ExtraSettings, _ bool) *nbpeer.Peer {
|
||||
return peer.Copy()
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) IsNotValidPeer(_ context.Context, _ string, _ *nbpeer.Peer, _ []string, _ *types.ExtraSettings) (bool, bool, error) {
|
||||
return false, false, nil
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) GetValidatedPeers(_ context.Context, _ string, _ []*types.Group, peers []*nbpeer.Peer, _ *types.ExtraSettings) (map[string]struct{}, error) {
|
||||
validatedPeers := make(map[string]struct{})
|
||||
for _, p := range peers {
|
||||
validatedPeers[p.ID] = struct{}{}
|
||||
}
|
||||
return validatedPeers, nil
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) GetInvalidPeers(_ context.Context, _ string, _ *types.ExtraSettings) (map[string]string, error) {
|
||||
return make(map[string]string), nil
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) PeerDeleted(_ context.Context, _, _ string, _ *types.ExtraSettings) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) SetPeerInvalidationListener(_ func(accountID string, peerIDs []string)) {
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) Stop(_ context.Context) {
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) ValidateFlowResponse(_ context.Context, _ string, flowResponse *proto.PKCEAuthorizationFlow) *proto.PKCEAuthorizationFlow {
|
||||
return flowResponse
|
||||
}
|
||||
@@ -30,7 +30,6 @@ import (
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const remoteJobsMinVer = "0.64.0"
|
||||
@@ -373,7 +372,7 @@ func (am *DefaultAccountManager) CreatePeerJob(ctx context.Context, accountID, p
|
||||
}
|
||||
|
||||
meetMinVer, err := posture.MeetsMinVersion(remoteJobsMinVer, p.Meta.WtVersion)
|
||||
if !version.IsDevelopmentVersion(p.Meta.WtVersion) && (!meetMinVer || err != nil) {
|
||||
if !strings.Contains(p.Meta.WtVersion, "dev") && (!meetMinVer || err != nil) {
|
||||
return status.Errorf(status.PreconditionFailed, "peer version %s does not meet the minimum required version %s for remote jobs", p.Meta.WtVersion, remoteJobsMinVer)
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
@@ -33,9 +32,6 @@ func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, err
|
||||
return true, nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("peer %s NB version %s is older than minimum allowed version %s",
|
||||
peer.ID, peer.Meta.WtVersion, n.MinVersion)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -100,8 +100,6 @@ func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *M
|
||||
return true, nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("peer %s OS version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinVersion)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -125,7 +123,5 @@ func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, ch
|
||||
return true, nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("peer %s kernel version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinKernelVersion)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ type AccountManagerMetrics struct {
|
||||
ctx context.Context
|
||||
updateAccountPeersDurationMs metric.Float64Histogram
|
||||
updateAccountPeersCounter metric.Int64Counter
|
||||
nmapCounter metric.Int64Counter
|
||||
getPeerNetworkMapDurationMs metric.Float64Histogram
|
||||
networkMapObjectCount metric.Int64Histogram
|
||||
peerMetaUpdateCount metric.Int64Counter
|
||||
@@ -59,6 +60,13 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nmapCounter, err := meter.Int64Counter("management.network.map.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of network maps computed, labeled by resource and operation trigger"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerMetaUpdateCount, err := meter.Int64Counter("management.account.peer.meta.update.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of updates with new meta data from the peers"))
|
||||
@@ -93,6 +101,7 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
|
||||
peerMetaUpdateCount: peerMetaUpdateCount,
|
||||
peerStatusUpdateCounter: peerStatusUpdateCounter,
|
||||
peerStatusUpdateDurationMs: peerStatusUpdateDurationMs,
|
||||
nmapCounter: nmapCounter,
|
||||
}, nil
|
||||
|
||||
}
|
||||
@@ -145,6 +154,16 @@ func (metrics *AccountManagerMetrics) CountUpdateAccountPeersTriggered(resource,
|
||||
)
|
||||
}
|
||||
|
||||
// CountNmapTriggered increments the counter for calculated network maps with resource and operation labels.
|
||||
func (metrics *AccountManagerMetrics) CountNmapTriggered(resource, operation string) {
|
||||
metrics.nmapCounter.Add(metrics.ctx, 1,
|
||||
metric.WithAttributes(
|
||||
attribute.String("resource", resource),
|
||||
attribute.String("operation", operation),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// CountPeerMetUpdate counts the number of peer meta updates
|
||||
func (metrics *AccountManagerMetrics) CountPeerMetUpdate() {
|
||||
metrics.peerMetaUpdateCount.Add(metrics.ctx, 1)
|
||||
|
||||
@@ -29,7 +29,6 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -1805,7 +1804,7 @@ func shouldCheckRulesForNativeSSH(supportsNative bool, rule *PolicyRule, peer *n
|
||||
|
||||
// peerSupportedFirewallFeatures checks if the peer version supports port ranges.
|
||||
func peerSupportedFirewallFeatures(peerVer string) supportedFeatures {
|
||||
if version.IsDevelopmentVersion(peerVer) {
|
||||
if strings.Contains(peerVer, "dev") {
|
||||
return supportedFeatures{true, true}
|
||||
}
|
||||
|
||||
|
||||
@@ -646,7 +646,41 @@ func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
|
||||
expectedPorts: []string{"20-25", "10-100", "22022"},
|
||||
},
|
||||
{
|
||||
name: "development version supports all features",
|
||||
name: "dev suffix version supports all features",
|
||||
peer: &nbpeer.Peer{
|
||||
ID: "peer1",
|
||||
SSHEnabled: true,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.50.0-dev",
|
||||
Flags: nbpeer.Flags{ServerSSHAllowed: true},
|
||||
},
|
||||
},
|
||||
rule: &PolicyRule{
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Ports: []string{"22"},
|
||||
},
|
||||
base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"},
|
||||
expectedPorts: []string{"22", "22022"},
|
||||
},
|
||||
{
|
||||
name: "dev suffix version supports all features",
|
||||
peer: &nbpeer.Peer{
|
||||
ID: "peer1",
|
||||
SSHEnabled: true,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "dev",
|
||||
Flags: nbpeer.Flags{ServerSSHAllowed: true},
|
||||
},
|
||||
},
|
||||
rule: &PolicyRule{
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Ports: []string{"22"},
|
||||
},
|
||||
base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"},
|
||||
expectedPorts: []string{"22", "22022"},
|
||||
},
|
||||
{
|
||||
name: "development suffix version supports all features",
|
||||
peer: &nbpeer.Peer{
|
||||
ID: "peer1",
|
||||
SSHEnabled: true,
|
||||
|
||||
@@ -762,7 +762,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
||||
}
|
||||
|
||||
// Ensure the initiator still has admin privileges
|
||||
if initiatorUser.HasAdminPower() && !freshInitiator.HasAdminPower() {
|
||||
if !freshInitiator.HasAdminPower() {
|
||||
return false, nil, nil, nil, status.Errorf(status.PermissionDenied, "initiator role was changed during request processing")
|
||||
}
|
||||
initiatorUser = freshInitiator
|
||||
@@ -906,19 +906,23 @@ func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUse
|
||||
return nil
|
||||
}
|
||||
|
||||
if !initiatorUser.HasAdminPower() {
|
||||
return status.Errorf(status.PermissionDenied, "only admins and owners can update users")
|
||||
}
|
||||
|
||||
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
|
||||
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
|
||||
}
|
||||
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && update.Role != initiatorUser.Role {
|
||||
return status.Errorf(status.PermissionDenied, "admins can't change their role")
|
||||
}
|
||||
if initiatorUser.Role == types.UserRoleAdmin && oldUser.Role == types.UserRoleOwner && update.Role != oldUser.Role {
|
||||
if initiatorUser.Role != types.UserRoleOwner && oldUser.Role == types.UserRoleOwner && update.Role != oldUser.Role {
|
||||
return status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user")
|
||||
}
|
||||
if initiatorUser.Role == types.UserRoleAdmin && oldUser.Role == types.UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() {
|
||||
if oldUser.Role == types.UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() {
|
||||
return status.Errorf(status.PermissionDenied, "unable to block owner user")
|
||||
}
|
||||
if initiatorUser.Role == types.UserRoleAdmin && update.Role == types.UserRoleOwner && update.Role != oldUser.Role {
|
||||
if initiatorUser.Role != types.UserRoleOwner && update.Role == types.UserRoleOwner && update.Role != oldUser.Role {
|
||||
return status.Errorf(status.PermissionDenied, "only owners can add owner role to other users")
|
||||
}
|
||||
if oldUser.IsServiceUser && update.Role == types.UserRoleOwner {
|
||||
|
||||
@@ -109,6 +109,22 @@ var debugStopCmd = &cobra.Command{
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var debugPerfCmd = &cobra.Command{
|
||||
Use: "perf <pool-cap>",
|
||||
Short: "Live-retune the tunnel buffer pool cap on all running clients",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runDebugPerfSet,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var debugRuntimeCmd = &cobra.Command{
|
||||
Use: "runtime",
|
||||
Short: "Show runtime stats (heap, goroutines, RSS)",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: runDebugRuntime,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var debugCaptureCmd = &cobra.Command{
|
||||
Use: "capture <account-id> [filter expression]",
|
||||
Short: "Capture packets on a client's WireGuard interface",
|
||||
@@ -159,6 +175,8 @@ func init() {
|
||||
debugCmd.AddCommand(debugLogCmd)
|
||||
debugCmd.AddCommand(debugStartCmd)
|
||||
debugCmd.AddCommand(debugStopCmd)
|
||||
debugCmd.AddCommand(debugPerfCmd)
|
||||
debugCmd.AddCommand(debugRuntimeCmd)
|
||||
debugCmd.AddCommand(debugCaptureCmd)
|
||||
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
@@ -220,6 +238,18 @@ func runDebugStop(cmd *cobra.Command, args []string) error {
|
||||
return getDebugClient(cmd).StopClient(cmd.Context(), args[0])
|
||||
}
|
||||
|
||||
func runDebugPerfSet(cmd *cobra.Command, args []string) error {
|
||||
n, err := strconv.ParseUint(args[0], 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid value %q: %w", args[0], err)
|
||||
}
|
||||
return getDebugClient(cmd).PerfSet(cmd.Context(), uint32(n))
|
||||
}
|
||||
|
||||
func runDebugRuntime(cmd *cobra.Command, _ []string) error {
|
||||
return getDebugClient(cmd).Runtime(cmd.Context())
|
||||
}
|
||||
|
||||
func runDebugCapture(cmd *cobra.Command, args []string) error {
|
||||
duration, _ := cmd.Flags().GetDuration("duration")
|
||||
forcePcap, _ := cmd.Flags().GetBool("pcap")
|
||||
|
||||
@@ -15,11 +15,22 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
"github.com/netbirdio/netbird/proxy"
|
||||
nbacme "github.com/netbirdio/netbird/proxy/internal/acme"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const (
|
||||
// envPreallocatedBuffers caps the per-tunnel buffer pool. Zero (unset)
|
||||
// keeps the upstream uncapped default.
|
||||
envPreallocatedBuffers = "NB_PROXY_PREALLOCATED_BUFFERS"
|
||||
// envMaxBatchSize overrides the per-tunnel batch size, which controls
|
||||
// how many buffers each receive/TUN worker eagerly allocates. Zero
|
||||
// (unset) keeps the platform default.
|
||||
envMaxBatchSize = "NB_PROXY_MAX_BATCH_SIZE"
|
||||
)
|
||||
|
||||
const DefaultManagementURL = "https://api.netbird.io:443"
|
||||
|
||||
// envProxyToken is the environment variable name for the proxy access token.
|
||||
@@ -148,6 +159,45 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
|
||||
logger.Infof("configured log level: %s", level)
|
||||
|
||||
var wgPool, wgBatch uint64
|
||||
var perf embed.Performance
|
||||
if raw := os.Getenv(envPreallocatedBuffers); raw != "" {
|
||||
n, err := strconv.ParseUint(raw, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid %s %q: %w", envPreallocatedBuffers, raw, err)
|
||||
}
|
||||
wgPool = n
|
||||
v := uint32(n)
|
||||
perf.PreallocatedBuffersPerPool = &v
|
||||
logger.Infof("tunnel preallocated buffers per pool: %d", n)
|
||||
}
|
||||
if raw := os.Getenv(envMaxBatchSize); raw != "" {
|
||||
n, err := strconv.ParseUint(raw, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid %s %q: %w", envMaxBatchSize, raw, err)
|
||||
}
|
||||
wgBatch = n
|
||||
v := uint32(n)
|
||||
perf.MaxBatchSize = &v
|
||||
logger.Infof("tunnel max batch size override: %d", n)
|
||||
}
|
||||
if wgPool > 0 {
|
||||
// Each bind recv goroutine (IPv4 + IPv6 + ICE relay) plus
|
||||
// RoutineReadFromTUN eagerly reserves `batch` message buffers for
|
||||
// the lifetime of the Device. A pool cap below that floor blocks
|
||||
// the receive pipeline at startup.
|
||||
batch := wgBatch
|
||||
if batch == 0 {
|
||||
batch = 128
|
||||
}
|
||||
const recvGoroutines = 4
|
||||
floor := batch * recvGoroutines
|
||||
if wgPool < floor {
|
||||
logger.Warnf("%s=%d is below the eager-allocation floor (~%d for batch=%d); startup may deadlock",
|
||||
envPreallocatedBuffers, wgPool, floor, batch)
|
||||
}
|
||||
}
|
||||
|
||||
switch forwardedProto {
|
||||
case "auto", "http", "https":
|
||||
default:
|
||||
@@ -188,6 +238,7 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
|
||||
WildcardCertDir: wildcardCertDir,
|
||||
WireguardPort: wgPort,
|
||||
Performance: perf,
|
||||
ProxyProtocol: proxyProtocol,
|
||||
PreSharedKey: preSharedKey,
|
||||
SupportsCustomPorts: supportsCustomPorts,
|
||||
|
||||
@@ -333,6 +333,63 @@ func (c *Client) printLogLevelResult(data map[string]any) {
|
||||
}
|
||||
}
|
||||
|
||||
// PerfSet live-retunes the tunnel buffer pool cap on all running embedded
|
||||
// clients. Batch size is not live-tunable; configure it at proxy startup.
|
||||
func (c *Client) PerfSet(ctx context.Context, value uint32) error {
|
||||
path := fmt.Sprintf("/debug/perf?value=%d", value)
|
||||
return c.fetchAndPrint(ctx, path, c.printPerfSet)
|
||||
}
|
||||
|
||||
func (c *Client) printPerfSet(data map[string]any) {
|
||||
if errMsg, ok := data["error"].(string); ok && errMsg != "" {
|
||||
c.printError(data)
|
||||
return
|
||||
}
|
||||
val, _ := data["value"].(float64)
|
||||
applied, _ := data["applied"].(float64)
|
||||
_, _ = fmt.Fprintf(c.out, "Pool cap set to: %d\n", uint32(val))
|
||||
_, _ = fmt.Fprintf(c.out, "Applied to %d live clients\n", int(applied))
|
||||
if failed, ok := data["failed"].(map[string]any); ok && len(failed) > 0 {
|
||||
_, _ = fmt.Fprintln(c.out, "Failed:")
|
||||
for k, v := range failed {
|
||||
_, _ = fmt.Fprintf(c.out, " %s: %v\n", k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Runtime fetches runtime stats (heap, goroutines, RSS).
|
||||
func (c *Client) Runtime(ctx context.Context) error {
|
||||
return c.fetchAndPrint(ctx, "/debug/runtime", c.printRuntime)
|
||||
}
|
||||
|
||||
func (c *Client) printRuntime(data map[string]any) {
|
||||
i := func(k string) uint64 {
|
||||
v, _ := data[k].(float64)
|
||||
return uint64(v)
|
||||
}
|
||||
mb := func(n uint64) string { return fmt.Sprintf("%.1f MB", float64(n)/(1<<20)) }
|
||||
|
||||
_, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"])
|
||||
_, _ = fmt.Fprintf(c.out, "Go: %v on %d CPU (GOMAXPROCS=%d)\n", data["go_version"], uint32(i("num_cpu")), uint32(i("gomaxprocs")))
|
||||
_, _ = fmt.Fprintf(c.out, "Goroutines: %d\n", i("goroutines"))
|
||||
_, _ = fmt.Fprintf(c.out, "Live objects: %d\n", i("live_objects"))
|
||||
_, _ = fmt.Fprintf(c.out, "GC: %d cycles, %v pause total\n", i("num_gc"), time.Duration(i("pause_total_ns")))
|
||||
_, _ = fmt.Fprintln(c.out, "Heap:")
|
||||
_, _ = fmt.Fprintf(c.out, " alloc: %s\n", mb(i("heap_alloc")))
|
||||
_, _ = fmt.Fprintf(c.out, " in-use: %s\n", mb(i("heap_inuse")))
|
||||
_, _ = fmt.Fprintf(c.out, " idle: %s\n", mb(i("heap_idle")))
|
||||
_, _ = fmt.Fprintf(c.out, " released: %s\n", mb(i("heap_released")))
|
||||
_, _ = fmt.Fprintf(c.out, " sys: %s\n", mb(i("heap_sys")))
|
||||
_, _ = fmt.Fprintf(c.out, "Total sys: %s\n", mb(i("sys")))
|
||||
if _, ok := data["vm_rss"]; ok {
|
||||
_, _ = fmt.Fprintln(c.out, "Process:")
|
||||
_, _ = fmt.Fprintf(c.out, " VmRSS: %s\n", mb(i("vm_rss")))
|
||||
_, _ = fmt.Fprintf(c.out, " VmSize: %s\n", mb(i("vm_size")))
|
||||
_, _ = fmt.Fprintf(c.out, " VmData: %s\n", mb(i("vm_data")))
|
||||
}
|
||||
_, _ = fmt.Fprintf(c.out, "Clients: %d (%d started)\n", i("clients"), i("started"))
|
||||
}
|
||||
|
||||
// StartClient starts a specific client.
|
||||
func (c *Client) StartClient(ctx context.Context, accountID string) error {
|
||||
path := "/debug/clients/" + url.PathEscape(accountID) + "/start"
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"maps"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -59,6 +61,7 @@ func sortedAccountIDs(m map[types.AccountID]roundtrip.ClientDebugInfo) []types.A
|
||||
type clientProvider interface {
|
||||
GetClient(accountID types.AccountID) (*nbembed.Client, bool)
|
||||
ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo
|
||||
ListClientsForStartup() map[types.AccountID]*nbembed.Client
|
||||
}
|
||||
|
||||
// InboundListenerInfo describes a per-account inbound listener as
|
||||
@@ -165,6 +168,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
h.handleListClients(w, r, wantJSON)
|
||||
case "/debug/health":
|
||||
h.handleHealth(w, r, wantJSON)
|
||||
case "/debug/perf":
|
||||
h.handlePerf(w, r)
|
||||
case "/debug/runtime":
|
||||
h.handleRuntime(w, r)
|
||||
default:
|
||||
if h.handleClientRoutes(w, r, path, wantJSON) {
|
||||
return
|
||||
@@ -258,10 +265,10 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
|
||||
}
|
||||
|
||||
if wantJSON {
|
||||
clientsJSON := make([]map[string]interface{}, 0, len(clients))
|
||||
clientsJSON := make([]map[string]any, 0, len(clients))
|
||||
for _, id := range sortedIDs {
|
||||
info := clients[id]
|
||||
clientsJSON = append(clientsJSON, map[string]interface{}{
|
||||
clientsJSON = append(clientsJSON, map[string]any{
|
||||
"account_id": info.AccountID,
|
||||
"service_count": info.ServiceCount,
|
||||
"service_keys": info.ServiceKeys,
|
||||
@@ -270,7 +277,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
|
||||
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||
})
|
||||
}
|
||||
resp := map[string]interface{}{
|
||||
resp := map[string]any{
|
||||
"version": version.NetbirdVersion(),
|
||||
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
||||
"client_count": len(clients),
|
||||
@@ -352,10 +359,10 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
|
||||
if h.inbound != nil {
|
||||
inboundAll = h.inbound.InboundListeners()
|
||||
}
|
||||
clientsJSON := make([]map[string]interface{}, 0, len(clients))
|
||||
clientsJSON := make([]map[string]any, 0, len(clients))
|
||||
for _, id := range sortedIDs {
|
||||
info := clients[id]
|
||||
row := map[string]interface{}{
|
||||
row := map[string]any{
|
||||
"account_id": info.AccountID,
|
||||
"service_count": info.ServiceCount,
|
||||
"service_keys": info.ServiceKeys,
|
||||
@@ -368,7 +375,7 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
|
||||
}
|
||||
clientsJSON = append(clientsJSON, row)
|
||||
}
|
||||
resp := map[string]interface{}{
|
||||
resp := map[string]any{
|
||||
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
||||
"client_count": len(clients),
|
||||
"clients": clientsJSON,
|
||||
@@ -458,7 +465,7 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc
|
||||
})
|
||||
|
||||
if wantJSON {
|
||||
resp := map[string]interface{}{
|
||||
resp := map[string]any{
|
||||
"account_id": accountID,
|
||||
"status": overview.FullDetailSummary(),
|
||||
}
|
||||
@@ -557,20 +564,20 @@ func (h *Handler) handleClientTools(w http.ResponseWriter, _ *http.Request, acco
|
||||
func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||
h.writeJSON(w, map[string]any{"error": "client not found"})
|
||||
return
|
||||
}
|
||||
|
||||
host := r.URL.Query().Get("host")
|
||||
portStr := r.URL.Query().Get("port")
|
||||
if host == "" || portStr == "" {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "host and port parameters required"})
|
||||
h.writeJSON(w, map[string]any{"error": "host and port parameters required"})
|
||||
return
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil || port < 1 || port > 65535 {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "invalid port"})
|
||||
h.writeJSON(w, map[string]any{"error": "invalid port"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -594,7 +601,7 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
|
||||
|
||||
conn, err := client.Dial(ctx, network, address)
|
||||
if err != nil {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": false,
|
||||
"host": host,
|
||||
"port": port,
|
||||
@@ -609,39 +616,38 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
|
||||
}
|
||||
|
||||
latency := time.Since(start)
|
||||
resp := map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": true,
|
||||
"host": host,
|
||||
"port": port,
|
||||
"remote": remote,
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
"latency": formatDuration(latency),
|
||||
}
|
||||
h.writeJSON(w, resp)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||
h.writeJSON(w, map[string]any{"error": "client not found"})
|
||||
return
|
||||
}
|
||||
|
||||
level := r.URL.Query().Get("level")
|
||||
if level == "" {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "level parameter required (trace, debug, info, warn, error)"})
|
||||
h.writeJSON(w, map[string]any{"error": "level parameter required (trace, debug, info, warn, error)"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := client.SetLogLevel(level); err != nil {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": true,
|
||||
"level": level,
|
||||
})
|
||||
@@ -652,7 +658,7 @@ const clientActionTimeout = 30 * time.Second
|
||||
func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||
h.writeJSON(w, map[string]any{"error": "client not found"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -660,14 +666,14 @@ func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, acco
|
||||
defer cancel()
|
||||
|
||||
if err := client.Start(ctx); err != nil {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": true,
|
||||
"message": "client started",
|
||||
})
|
||||
@@ -676,7 +682,7 @@ func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, acco
|
||||
func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||
h.writeJSON(w, map[string]any{"error": "client not found"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -684,19 +690,125 @@ func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accou
|
||||
defer cancel()
|
||||
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": true,
|
||||
"message": "client stopped",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handlePerf(w http.ResponseWriter, r *http.Request) {
|
||||
raw := r.URL.Query().Get("value")
|
||||
if raw == "" {
|
||||
http.Error(w, "value parameter is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
n, err := strconv.ParseUint(raw, 10, 32)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("invalid value %q: %v", raw, err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
capN := uint32(n)
|
||||
applied := 0
|
||||
failed := map[string]string{}
|
||||
for accountID, client := range h.provider.ListClientsForStartup() {
|
||||
if err := client.SetPerformance(nbembed.Performance{PreallocatedBuffersPerPool: &capN}); err != nil {
|
||||
failed[string(accountID)] = err.Error()
|
||||
continue
|
||||
}
|
||||
applied++
|
||||
}
|
||||
|
||||
resp := map[string]any{
|
||||
"success": true,
|
||||
"value": capN,
|
||||
"applied": applied,
|
||||
}
|
||||
if len(failed) > 0 {
|
||||
resp["failed"] = failed
|
||||
}
|
||||
h.writeJSON(w, resp)
|
||||
}
|
||||
|
||||
// handleRuntime returns cheap runtime and process stats. Safe to hit on a
|
||||
// running proxy; does not read pprof profiles.
|
||||
func (h *Handler) handleRuntime(w http.ResponseWriter, _ *http.Request) {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
|
||||
clients := h.provider.ListClientsForDebug()
|
||||
started := 0
|
||||
for _, c := range clients {
|
||||
if c.HasClient {
|
||||
started++
|
||||
}
|
||||
}
|
||||
|
||||
resp := map[string]any{
|
||||
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
||||
"goroutines": runtime.NumGoroutine(),
|
||||
"num_cpu": runtime.NumCPU(),
|
||||
"gomaxprocs": runtime.GOMAXPROCS(0),
|
||||
"go_version": runtime.Version(),
|
||||
"heap_alloc": m.HeapAlloc,
|
||||
"heap_inuse": m.HeapInuse,
|
||||
"heap_idle": m.HeapIdle,
|
||||
"heap_released": m.HeapReleased,
|
||||
"heap_sys": m.HeapSys,
|
||||
"sys": m.Sys,
|
||||
"live_objects": m.Mallocs - m.Frees,
|
||||
"num_gc": m.NumGC,
|
||||
"pause_total_ns": m.PauseTotalNs,
|
||||
"clients": len(clients),
|
||||
"started": started,
|
||||
}
|
||||
|
||||
if proc := readProcStatus(); proc != nil {
|
||||
resp["vm_rss"] = proc["VmRSS"]
|
||||
resp["vm_size"] = proc["VmSize"]
|
||||
resp["vm_data"] = proc["VmData"]
|
||||
}
|
||||
|
||||
h.writeJSON(w, resp)
|
||||
}
|
||||
|
||||
// readProcStatus parses /proc/self/status on Linux and returns size fields
|
||||
// in bytes. Returns nil on non-Linux or read failure.
|
||||
func readProcStatus() map[string]uint64 {
|
||||
raw, err := os.ReadFile("/proc/self/status")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
out := map[string]uint64{}
|
||||
for _, line := range strings.Split(string(raw), "\n") {
|
||||
k, v, ok := strings.Cut(line, ":")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if k != "VmRSS" && k != "VmSize" && k != "VmData" {
|
||||
continue
|
||||
}
|
||||
fields := strings.Fields(v)
|
||||
if len(fields) < 1 {
|
||||
continue
|
||||
}
|
||||
n, err := strconv.ParseUint(fields[0], 10, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
// Values are reported in kB.
|
||||
out[k] = n * 1024
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
const maxCaptureDuration = 30 * time.Minute
|
||||
|
||||
// handleCapture streams a pcap or text packet capture for the given client.
|
||||
@@ -825,7 +937,7 @@ func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON
|
||||
h.writeJSON(w, resp)
|
||||
}
|
||||
|
||||
func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interface{}) {
|
||||
func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data any) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
tmpl := h.getTemplates()
|
||||
if tmpl == nil {
|
||||
@@ -838,7 +950,7 @@ func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interf
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) writeJSON(w http.ResponseWriter, v interface{}) {
|
||||
func (h *Handler) writeJSON(w http.ResponseWriter, v any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetIndent("", " ")
|
||||
|
||||
@@ -131,6 +131,7 @@ type ClientConfig struct {
|
||||
MgmtAddr string
|
||||
WGPort uint16
|
||||
PreSharedKey string
|
||||
Performance embed.Performance
|
||||
// BlockInbound mirrors embed.Options.BlockInbound. Set to true on the
|
||||
// standalone proxy where the embedded client never accepts inbound;
|
||||
// set to false on the private/embedded proxy so the engine creates
|
||||
@@ -306,7 +307,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
ManagementURL: n.clientCfg.MgmtAddr,
|
||||
PrivateKey: privateKey.String(),
|
||||
LogLevel: log.WarnLevel.String(),
|
||||
BlockInbound: n.clientCfg.BlockInbound,
|
||||
BlockInbound: n.clientCfg.BlockInbound,
|
||||
// The embedded proxy peer must never be a stepping stone into
|
||||
// the proxy host's LAN: it only exists to reach NetBird mesh
|
||||
// targets or, when direct_upstream is set, the host network
|
||||
@@ -315,6 +316,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
BlockLANAccess: true,
|
||||
WireguardPort: &wgPort,
|
||||
PreSharedKey: n.clientCfg.PreSharedKey,
|
||||
Performance: n.clientCfg.Performance,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create netbird client: %w", err)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
"github.com/netbirdio/netbird/proxy/internal/acme"
|
||||
)
|
||||
|
||||
@@ -89,6 +90,10 @@ type Config struct {
|
||||
// PreSharedKey is the WireGuard pre-shared key used between the
|
||||
// proxy's embedded clients and peers.
|
||||
PreSharedKey string
|
||||
// Performance configures the tunnel pool/batch sizes for every
|
||||
// embedded client this proxy creates. Zero values fall back to
|
||||
// upstream defaults.
|
||||
Performance embed.Performance
|
||||
|
||||
// SupportsCustomPorts indicates whether the proxy can bind arbitrary
|
||||
// ports for TCP/UDP/TLS services.
|
||||
@@ -148,6 +153,7 @@ func New(cfg Config) *Server {
|
||||
WireguardPort: cfg.WireguardPort,
|
||||
ProxyProtocol: cfg.ProxyProtocol,
|
||||
PreSharedKey: cfg.PreSharedKey,
|
||||
Performance: cfg.Performance,
|
||||
SupportsCustomPorts: cfg.SupportsCustomPorts,
|
||||
RequireSubdomain: cfg.RequireSubdomain,
|
||||
Private: cfg.Private,
|
||||
|
||||
@@ -41,6 +41,7 @@ import (
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
||||
"github.com/netbirdio/netbird/proxy/internal/acme"
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||
@@ -185,6 +186,9 @@ type Server struct {
|
||||
// single-account deployments; multiple accounts will fail to bind
|
||||
// the same port.
|
||||
WireguardPort uint16
|
||||
// Performance configures the tunnel pool/batch sizes for every
|
||||
// embedded client this proxy spawns.
|
||||
Performance embed.Performance
|
||||
// ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners.
|
||||
// When enabled, the real client IP is extracted from the PROXY header
|
||||
// sent by upstream L4 proxies that support PROXY protocol.
|
||||
@@ -333,6 +337,8 @@ func (s *Server) Start(ctx context.Context) error {
|
||||
s.runCancel = runCancel
|
||||
|
||||
s.initNetBirdClient()
|
||||
// Create health checker before the mapping worker so it can track
|
||||
// management connectivity from the first stream connection.
|
||||
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
|
||||
|
||||
s.crowdsecRegistry = crowdsec.NewRegistry(s.CrowdSecAPIURL, s.CrowdSecAPIKey, log.NewEntry(s.Logger))
|
||||
@@ -529,6 +535,7 @@ func (s *Server) initNetBirdClient() {
|
||||
MgmtAddr: s.ManagementAddress,
|
||||
WGPort: s.WireguardPort,
|
||||
PreSharedKey: s.PreSharedKey,
|
||||
Performance: s.Performance,
|
||||
// On --private the embedded client serves per-account inbound
|
||||
// listeners and must apply management's ACL: keep BlockInbound off
|
||||
// so the engine creates the ACL manager. On the standalone proxy
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
@@ -103,7 +103,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ia, _ := integrations.NewIntegratedValidator(ctx, peersManger, settingsManagerMock, eventStore, cacheStore)
|
||||
ia, _ := validator.NewIntegratedValidator(ctx, peersManger, settingsManagerMock, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -2,75 +2,19 @@ package version
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
|
||||
v "github.com/hashicorp/go-version"
|
||||
)
|
||||
|
||||
// DevelopmentVersion is the value of NetbirdVersion() for non-release builds.
|
||||
// Wire-format consumers (management server, dashboard) match against this
|
||||
// string, so it must not change without coordinating those consumers.
|
||||
const DevelopmentVersion = "development"
|
||||
|
||||
// will be replaced with the release version when using goreleaser
|
||||
var version = DevelopmentVersion
|
||||
var version = "development"
|
||||
|
||||
var (
|
||||
VersionRegexp = regexp.MustCompile("^" + v.VersionRegexpRaw + "$")
|
||||
SemverRegexp = regexp.MustCompile("^" + v.SemverRegexpRaw + "$")
|
||||
)
|
||||
|
||||
// NetbirdVersion returns the Netbird version. For non-release builds the
|
||||
// value is the literal DevelopmentVersion constant; the VCS revision is
|
||||
// exposed separately via NetbirdCommit so the wire format stays stable.
|
||||
// NetbirdVersion returns the Netbird version
|
||||
func NetbirdVersion() string {
|
||||
return version
|
||||
}
|
||||
|
||||
// NetbirdCommit returns the VCS revision (truncated to 12 chars) of the
|
||||
// build, with a "-dirty" suffix when the working tree was modified.
|
||||
// Returns an empty string when no build info is embedded (e.g. release
|
||||
// builds compiled by goreleaser without -buildvcs).
|
||||
func NetbirdCommit() string {
|
||||
info, ok := debug.ReadBuildInfo()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
var revision string
|
||||
var modified bool
|
||||
for _, s := range info.Settings {
|
||||
switch s.Key {
|
||||
case "vcs.revision":
|
||||
revision = s.Value
|
||||
case "vcs.modified":
|
||||
modified = s.Value == "true"
|
||||
}
|
||||
}
|
||||
|
||||
if revision == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if len(revision) > 12 {
|
||||
revision = revision[:12]
|
||||
}
|
||||
|
||||
if modified {
|
||||
revision += "-dirty"
|
||||
}
|
||||
return revision
|
||||
}
|
||||
|
||||
// IsDevelopmentVersion reports whether the given version string identifies
|
||||
// a non-release / development build. It is the single source of truth for
|
||||
// "is this a dev build" checks across the codebase; use it instead of
|
||||
// comparing against the "development" literal or ad-hoc substring checks.
|
||||
//
|
||||
// Matches the bare DevelopmentVersion constant as well as any future
|
||||
// extension such as "development-<sha>" or "development-<sha>-dirty",
|
||||
// while excluding tagged prereleases like "v0.31.1-dev".
|
||||
func IsDevelopmentVersion(v string) bool {
|
||||
return strings.HasPrefix(v, DevelopmentVersion)
|
||||
}
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
package version
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestIsDevelopmentVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
version string
|
||||
want bool
|
||||
}{
|
||||
{"development", true},
|
||||
{"development-0823f3ff9ab1", true},
|
||||
{"development-0823f3ff9ab1-dirty", true},
|
||||
{"0.50.0", false},
|
||||
{"v0.31.1-dev", false},
|
||||
{"1.0.0-dev", false},
|
||||
{"dev", false},
|
||||
{"", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.version, func(t *testing.T) {
|
||||
if got := IsDevelopmentVersion(tt.version); got != tt.want {
|
||||
t.Errorf("IsDevelopmentVersion(%q) = %v, want %v", tt.version, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user