Compare commits

..

15 Commits

Author SHA1 Message Date
Viktor Liu
dd301f2691 Add daemon socket owner enforcement via SO_PEERCRED 2026-05-29 13:28:37 +02:00
Zoltan Papp
174dc24867 [management] Add SSO session extend flow (management) (#6197)
* add SSO session extend flow (management)

Adds the management-server half of the SSO session-extension feature:

- New ExtendAuthSession gRPC RPC that refreshes a peer's session expiry
  using a fresh JWT, validated through the same pipeline as Login but
  without tearing down the tunnel or redoing the NetworkMap sync.
- Per-peer SessionExpiresAt timestamp on every LoginResponse and
  SyncResponse so connected clients learn the deadline on the existing
  long-lived stream, and admin-side changes (toggling expiration,
  changing the expiration window) reach every peer within seconds.
- SessionExpiresAt(...) helper on Peer that derives the absolute UTC
  deadline from LastLogin + the account-level PeerLoginExpiration
  setting, returning zero when the peer is not SSO-tracked or expiration
  is disabled.

The matching client-side consumer of these fields lands separately.

* encode SessionExpiresAt as 3-state on the wire

Previously the `sessionExpiresAt` field on LoginResponse, SyncResponse
and ExtendAuthSessionResponse was 2-state: a valid timestamp meant
"new deadline", and nil meant "clear". That conflated two distinct
meanings — "no info in this snapshot" vs "expiry is explicitly off /
peer is not SSO-tracked" — so a Sync push that legitimately couldn't
compute the deadline (settings lookup failed) would silently clear the
client's anchor and lose the warning window.

Three states now, encoded on the same field number (no .proto schema
churn — only comments and the server-side encoder change):

  - nil pointer (field absent) → "no info"; client preserves anchor
  - &Timestamp{} (seconds=0, nanos=0) → explicit "disabled / not SSO"
    sentinel; client clears
  - valid timestamp → new absolute UTC deadline

A new encodeSessionExpiresAt helper centralises the zero/non-zero
encoding and is shared by the Sync, Login and ExtendAuthSession
builders. The Sync builder still emits nil when settings are missing.
Login and ExtendAuthSession always carry an authoritative value.

The matching client-side decoder lands on feature/session-extend.

* add UserExtendedPeerSession activity event

ExtendAuthSession previously reused UserLoggedInPeer for its audit
record, which conflated two distinct user actions: a full interactive
SSO login (tunnel re-established, network map resync) versus an
in-place deadline refresh (tunnel untouched). Auditors reading the log
couldn't tell which one happened, and downstream dashboards/alerts on
"login" volume were polluted by routine extends.

Adds a dedicated UserExtendedPeerSession Activity (code 125,
"user.peer.session.extend") and switches ExtendPeerSession over to it.
The peer-extend audit trail is now distinguishable from interactive
logins.

* make ExtendAuthSession JWT-retry backoff cancellable

Skip the retry log and 200ms wait on the final attempt, and replace the
uncancellable time.Sleep with a select on time.After/ctx.Done so an
upstream cancellation aborts the wait instead of running it to
completion.
2026-05-28 19:14:14 +02:00
Riccardo Manfrin
7ea5e37dd4 [client] Improve rosenpass support (#6136)
* Updates rosenpass version

go-rosenpass v0.4.0 → v0.5.42 bump — detailed findings

Change summary
cunicu.li/go-rosenpass  v0.4.0  → v0.5.42   (target)
cilium/ebpf             v0.15.0 → v0.19.0   (transitive)
gopacket/gopacket       v1.1.1  → v1.4.0    (transitive)
wireguard               2023-07 → 2023-12   (transitive)
wireguard/wgctrl        2023-04 → 2024-12   (transitive)

Wire interop

v0.4.0 (in v0.70.5) <-> v0.5.42 OK
v0.5.42 <-> v0.5.42 OK

Quantum resistance: true both ends

---
**Replay error eliminated.**

Before (on v0.4.0):

`ERROR Failed to handle message: failed to load biscuit (ICR1): detected replay`

Recurring every ~50ms for minutes at a time. Gone entirely after both ends upgraded to v0.5.42. Upstream fix in biscuit/replay handling between v0.4.x and v0.5.x series.

* Fixup [::]:port socket trying to send to v4

* Adds more tests on netbird<->rosenpass interactions

* Anticipates rp handler creation before generateConfig

* [client] Moves deterministic key gen into rosenpass

* go mod tidy

* Adds reminder to reason about rosenpass surface area

* Apply code rabbit suggestions
2026-05-28 09:01:18 +02:00
Riccardo Manfrin
9d7ef9b255 [client] Fix statemanager possible deadlock (#6228)
1. Stop() takes m.mu.Lock() and defers m.mu.Unlock()
2. <-m.done under lock
3. periodicStateSave defers close(m.done)
4. periodicStateSave calls PersistState() (line 256) which does m.mu.Lock()

Double Stop() remains idempotent: second cancel() on dead ctx
 (no-op) and reads done already closed (immediate return).
2026-05-28 08:54:15 +02:00
Pascal Fischer
944a258459 [management] extend nmap monitoring (#6271) 2026-05-27 16:56:02 +02:00
Pascal Fischer
1f9a829f2c [management] update log levels (#6266) 2026-05-27 11:43:49 +02:00
Bethuel Mmbaga
14af179556 [management] Refactor management server bootstrap (#6256) 2026-05-26 17:44:28 +03:00
Pascal Fischer
1fbb5e6d5d [management] fix owner role update (#6264) 2026-05-26 16:37:58 +02:00
Viktor Liu
6771e35d57 [client] Release js.FuncOf callbacks in wasm ssh and rdp to prevent leaks (#5982) 2026-05-26 14:32:39 +02:00
Viktor Liu
e89b1e0596 [proxy, client] Bound embed client WireGuard per-Device memory (#5962) 2026-05-26 11:51:53 +02:00
Philip Laine
d542c60e21 Refactor Linux system info to use syscalls (#6230) 2026-05-25 21:00:24 +02:00
Viktor Liu
4983b5cf17 [client] Match DNS wildcard handlers on label boundaries (#6255) 2026-05-25 18:38:48 +02:00
Viktor Liu
b3b0feb3b8 [client] Filter scoped/cloned default routes from BSD network monitor RTM_ADD (#6208) 2026-05-25 18:38:21 +02:00
Maycon Santos
7aebdd69dd [management, client, proxy] add expose NetBird-only services over tunnel peers (#6226)
Adds a new "private" service mode for the reverse proxy: services reachable exclusively over the embedded WireGuard tunnel, gated by per-peer group membership instead of operator auth schemes.

Wire contract
- ProxyMapping.private (field 13): the proxy MUST call ValidateTunnelPeer and fail closed; operator schemes are bypassed.
- ProxyCapabilities.private (4) + supports_private_service (5): capability gate. Management never streams private mappings to proxies that don't claim the capability; the broadcast path applies the same filter via filterMappingsForProxy.
- ValidateTunnelPeer RPC: resolves an inbound tunnel IP to a peer, checks the peer's groups against service.AccessGroups, and mints a session JWT on success. checkPeerGroupAccess fails closed when a private service has empty AccessGroups.
- ValidateSession/ValidateTunnelPeer responses now carry peer_group_ids + peer_group_names so the proxy can authorise policy-aware middlewares without an extra management round-trip.
- ProxyInboundListener + SendStatusUpdate.inbound_listener: per-account inbound listener state surfaced to dashboards.
- PathTargetOptions.direct_upstream (11): bypass the embedded NetBird client and dial the target via the proxy host's network stack for upstreams reachable without WireGuard.

Data model
- Service.Private (bool) + Service.AccessGroups ([]string, JSON- serialised). Validate() rejects bearer auth on private services. Copy() deep-copies AccessGroups. pgx getServices loads the columns.
- DomainConfig.Private threaded into the proxy auth middleware. Request handler routes private services through forwardWithTunnelPeer and returns 403 on validation failure.
- Account-level SynthesizePrivateServiceZones (synthetic DNS) and injectPrivateServicePolicies (synthetic ACL) gate on len(svc.AccessGroups) > 0.

Proxy
- /netbird proxy --private (embedded mode) flag; Config.Private in proxy/lifecycle.go.
- Per-account inbound listener (proxy/inbound.go) binding HTTP/HTTPS on the embedded NetBird client's WireGuard tunnel netstack.
- proxy/internal/auth/tunnel_cache: ValidateTunnelPeer response cache with single-flight de-duplication and per-account eviction.
- Local peerstore short-circuit: when the inbound IP isn't in the account roster, deny fast without an RPC.
- proxy/server.go reports SupportsPrivateService=true and redacts the full ProxyMapping JSON from info logs (auth_token + header-auth hashed values now only at debug level).

Identity forwarding
- ValidateSessionJWT returns user_id, email, method, groups, group_names. sessionkey.Claims carries Email + Groups + GroupNames so the proxy can stamp identity onto upstream requests without an extra management round-trip on every cookie-bearing request.
- CapturedData carries userEmail / userGroups / userGroupNames; the proxy stamps X-NetBird-User and X-NetBird-Groups on r.Out from the authenticated identity (strips client-supplied values first to prevent spoofing).
- AccessLog.UserGroups: access-log enrichment captures the user's group memberships at write time so the dashboard can render group context without reverse-resolving stale memberships.

OpenAPI/dashboard surface
- ReverseProxyService gains private + access_groups; ReverseProxyCluster gains private + supports_private. ReverseProxyTarget target_type enum gains "cluster". ServiceTargetOptions gains direct_upstream. ProxyAccessLog gains user_groups.
2026-05-25 17:41:50 +02:00
Viktor Liu
0358be2313 [client] Revert "Clean up legacy 32-bit and HKCU registry entries on Windows install (#6176)" (#6232)
This reverts commit d927ef468a.
2026-05-21 16:27:12 +02:00
170 changed files with 12749 additions and 2345 deletions

View File

@@ -20,34 +20,66 @@ jobs:
per_page: 100,
});
const pbFiles = files.filter(f => f.filename.endsWith('.pb.go'));
const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename);
if (missingPatch.length > 0) {
core.setFailed(
`Cannot inspect patch data for:\n` +
missingPatch.map(f => `- ${f}`).join('\n') +
`\nThis can happen with very large PRs. Verify proto versions manually.`
);
const modifiedPbFiles = files.filter(
f => f.filename.endsWith('.pb.go') && f.status === 'modified'
);
if (modifiedPbFiles.length === 0) {
console.log('No modified .pb.go files to check');
return;
}
const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
const violations = [];
for (const file of pbFiles) {
const changed = file.patch
.split('\n')
.filter(line => versionPattern.test(line));
if (changed.length > 0) {
const versionPattern = /^\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
const baseSha = context.payload.pull_request.base.sha;
const headSha = context.payload.pull_request.head.sha;
async function getVersionHeader(path, ref) {
try {
const res = await github.rest.repos.getContent({
owner: context.repo.owner,
repo: context.repo.repo,
path,
ref,
});
if (!res.data.content) {
return { ok: false, reason: 'no inline content (file too large)' };
}
const content = Buffer.from(res.data.content, 'base64').toString('utf8');
const lines = content
.split('\n')
.slice(0, 20)
.filter(line => versionPattern.test(line));
return { ok: true, lines };
} catch (e) {
return { ok: false, reason: e.message };
}
}
const violations = [];
for (const file of modifiedPbFiles) {
const [base, head] = await Promise.all([
getVersionHeader(file.filename, baseSha),
getVersionHeader(file.filename, headSha),
]);
if (!base.ok || !head.ok) {
core.warning(
`Skipping ${file.filename}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}`
);
continue;
}
if (base.lines.join('\n') !== head.lines.join('\n')) {
violations.push({
file: file.filename,
lines: changed,
base: base.lines,
head: head.lines,
});
}
}
if (violations.length > 0) {
const details = violations.map(v =>
`${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}`
`${v.file}:\n` +
` base:\n${v.base.map(l => ' ' + l).join('\n') || ' (none)'}\n` +
` head:\n${v.head.map(l => ' ' + l).join('\n') || ' (none)'}`
).join('\n\n');
core.setFailed(

View File

@@ -137,7 +137,7 @@ func (pm *ProfileManager) SwitchProfile(profileName string) error {
// AddProfile creates a new profile
func (pm *ProfileManager) AddProfile(profileName string) error {
// Use ServiceManager (creates profile in profiles/ directory)
if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil {
if err := pm.serviceMgr.AddProfile(profileName, androidUsername, nil); err != nil {
return fmt.Errorf("failed to add profile: %w", err)
}

84
client/cmd/owner.go Normal file
View File

@@ -0,0 +1,84 @@
package cmd
import (
"fmt"
"strconv"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/proto"
)
var ownerCmd = &cobra.Command{
Use: "owner",
Short: "Manage daemon owner UIDs",
Long: `Manage the list of UIDs allowed to control the NetBird daemon.
Owners are persisted in the active profile config and survive daemon restarts.
The first call from the user logged in at the GUI / console session claims
ownership automatically; these subcommands cover the rest of the lifecycle.`,
}
var ownerAddCmd = &cobra.Command{
Use: "add <uid>",
Short: "Add a UID as an owner of the daemon",
Long: `Add a UID to the active profile's owner list. Requires root or an
existing owner. Use this to grant another local user permanent access without
having them log in at the console first.`,
Args: cobra.ExactArgs(1),
RunE: addOwnerFunc,
}
var ownerResetCmd = &cobra.Command{
Use: "reset",
Short: "Clear the daemon's owner list",
Long: `Clear the active profile's owner list, returning the daemon to its
unconfigured state. The next call from the active console-session user will
re-claim ownership. Requires root.`,
RunE: resetOwnerFunc,
}
func addOwnerFunc(cmd *cobra.Command, args []string) error {
if err := setupCmd(cmd); err != nil {
return err
}
uid, err := strconv.ParseUint(args[0], 10, 32)
if err != nil {
return fmt.Errorf("parse uid %q: %w", args[0], err)
}
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
return fmt.Errorf("connect to daemon: %w", err)
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
if _, err := client.AddOwner(cmd.Context(), &proto.AddOwnerRequest{Uid: uint32(uid)}); err != nil {
return fmt.Errorf("add owner: %w", err)
}
cmd.Printf("UID %d added as owner\n", uid)
return nil
}
func resetOwnerFunc(cmd *cobra.Command, _ []string) error {
if err := setupCmd(cmd); err != nil {
return err
}
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
return fmt.Errorf("connect to daemon: %w", err)
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
if _, err := client.ResetOwner(cmd.Context(), &proto.ResetOwnerRequest{}); err != nil {
return fmt.Errorf("reset owner: %w", err)
}
cmd.Println("daemon owner list cleared; next call from the active console user will re-claim ownership")
return nil
}

View File

@@ -23,6 +23,7 @@ import (
"google.golang.org/grpc/credentials/insecure"
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
"github.com/netbirdio/netbird/client/internal/owner"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
@@ -156,8 +157,12 @@ func init() {
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd)
rootCmd.AddCommand(profileCmd)
rootCmd.AddCommand(ownerCmd)
rootCmd.AddCommand(exposeCmd)
ownerCmd.AddCommand(ownerAddCmd)
ownerCmd.AddCommand(ownerResetCmd)
networksCMD.AddCommand(routesListCmd)
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
@@ -250,11 +255,24 @@ func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, e
return grpc.DialContext(
ctx,
strings.TrimPrefix(addr, "tcp://"),
grpc.WithTransportCredentials(insecure.NewCredentials()),
daemonDialTransportOption(addr),
grpc.WithBlock(),
)
}
// daemonDialTransportOption returns the appropriate transport credentials for connecting
// to the daemon. On Unix socket platforms, uses Unix transport credentials so the server
// can extract the caller's UID for owner verification. Otherwise, uses insecure credentials.
func daemonDialTransportOption(addr string) grpc.DialOption {
if strings.HasPrefix(addr, "unix://") {
creds := owner.NewUnixTransportCredentials()
if creds != nil {
return grpc.WithTransportCredentials(creds)
}
}
return grpc.WithTransportCredentials(insecure.NewCredentials())
}
// WithBackOff execute function in backoff cycle.
func WithBackOff(bf func() error) error {
return backoff.RetryNotify(bf, CLIBackOffSettings, func(err error, duration time.Duration) {

View File

@@ -16,6 +16,7 @@ import (
"github.com/spf13/cobra"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/client/internal/owner"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/client/system"
@@ -29,9 +30,6 @@ func (p *program) Start(svc service.Service) error {
// Collect static system and platform information
system.UpdateStaticInfoAsync()
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
p.serv = grpc.NewServer()
split := strings.Split(daemonAddr, "://")
switch split[0] {
case "unix":
@@ -47,6 +45,12 @@ func (p *program) Start(svc service.Service) error {
return fmt.Errorf("unsupported daemon address protocol: %v", split[0])
}
// Set up owner enforcement for Unix sockets.
configAdapter := &owner.ConfigAdapter{}
serverOpts := ownerServerOpts(split[0], configAdapter)
p.serv = grpc.NewServer(serverOpts...)
listen, err := net.Listen(split[0], split[1])
if err != nil {
return fmt.Errorf("listen daemon interface: %w", err)
@@ -65,6 +69,8 @@ func (p *program) Start(svc service.Service) error {
if err := serverInstance.Start(); err != nil {
log.Fatalf("failed to start daemon: %v", err)
}
configAdapter.SetBackend(serverInstance)
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
p.serverInstanceMu.Lock()
@@ -79,6 +85,32 @@ func (p *program) Start(svc service.Service) error {
return nil
}
// ownerServerOpts returns gRPC server options for owner enforcement.
// On Unix socket platforms, this includes transport credentials for peer credential
// extraction and interceptors that check the caller's UID. On other platforms or TCP,
// no owner enforcement is applied and a warning is logged so operators know the daemon
// is running without per-user authorization.
func ownerServerOpts(protocol string, configAdapter *owner.ConfigAdapter) []grpc.ServerOption {
if protocol != "unix" {
log.Warnf("daemon socket owner enforcement is not applied for protocol %q", protocol)
return nil
}
creds := owner.NewUnixTransportCredentials()
if creds == nil {
log.Warnf("daemon socket owner enforcement unavailable on this platform; daemon will accept any local connection")
return nil
}
interceptor := owner.NewInterceptor(configAdapter)
return []grpc.ServerOption{
grpc.Creds(creds),
grpc.ChainUnaryInterceptor(interceptor.UnaryInterceptor()),
grpc.ChainStreamInterceptor(interceptor.StreamInterceptor()),
}
}
func (p *program) Stop(srv service.Service) error {
p.serverInstanceMu.Lock()
if p.serverInstance != nil {

View File

@@ -11,7 +11,7 @@ import (
"go.opentelemetry.io/otel"
"google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
nbcache "github.com/netbirdio/netbird/management/server/cache"
@@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
t.Fatal(err)
}
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
iv, _ := validator.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
require.NoError(t, err)

View File

@@ -44,6 +44,9 @@ const (
profileNameFlag = "profile"
profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used."
claimOwnerFlag = "owner"
claimOwnerDesc = "claim owner privileges for this profile, restricting daemon control to the current user and root"
)
var (
@@ -54,6 +57,7 @@ var (
showQR bool
profileName string
configPath string
claimOwner bool
upCmd = &cobra.Command{
Use: "up",
@@ -87,6 +91,7 @@ func init() {
upCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ")
upCmd.PersistentFlags().BoolVar(&claimOwner, claimOwnerFlag, false, claimOwnerDesc)
}
@@ -331,6 +336,7 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
if _, err := client.Up(ctx, &proto.UpRequest{
ProfileName: &activeProf.Name,
Username: &username,
ClaimOwner: claimOwner,
}); err != nil {
return fmt.Errorf("call service up method: %v", err)
}

View File

@@ -29,7 +29,7 @@ func TestUpDaemon(t *testing.T) {
}
sm := profilemanager.ServiceManager{}
err = sm.AddProfile("test1", currUser.Username)
err = sm.AddProfile("test1", currUser.Username, nil)
if err != nil {
t.Fatalf("failed to add profile: %v", err)
return

View File

@@ -12,6 +12,7 @@ import (
"sync"
"github.com/sirupsen/logrus"
wgdevice "golang.zx2c4.com/wireguard/device"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface"
@@ -84,6 +85,12 @@ type Options struct {
DisableIPv6 bool
// BlockInbound blocks all inbound connections from peers
BlockInbound bool
// BlockLANAccess blocks the embedded peer from reaching the host's
// LAN (RFC 1918, link-local, loopback) when it's used as a routing
// peer. Mirrors profilemanager.ConfigInput.BlockLANAccess. Useful
// when the embedded client must never act as a stepping stone into
// the host's local network (e.g. the proxy's overlay peer).
BlockLANAccess bool
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
WireguardPort *int
// MTU is the MTU for the tunnel interface.
@@ -94,6 +101,26 @@ type Options struct {
MTU *uint16
// DNSLabels defines additional DNS labels configured in the peer.
DNSLabels []string
// Performance configures the tunnel's buffer pool cap and batch size.
Performance Performance
}
// Performance configures the embedded client's tunnel memory/throughput knobs.
//
// These settings are process-global: any non-nil field also becomes the
// default for Clients constructed by later embed.New calls in the same
// process. Nil fields are ignored.
type Performance struct {
// PreallocatedBuffersPerPool caps the per-tunnel buffer pool. Zero
// leaves the pool unbounded. Lower values trade throughput for a
// tighter memory ceiling. May also be changed on a running Client via
// Client.SetPerformance, provided this field was nonzero at construction.
PreallocatedBuffersPerPool *uint32
// MaxBatchSize overrides the number of packets the tunnel reads or
// writes per syscall, which also bounds eager buffer allocation per
// worker. Zero uses the platform default. Applied at construction
// only; ignored by Client.SetPerformance.
MaxBatchSize *uint32
}
// validateCredentials checks that exactly one credential type is provided
@@ -175,6 +202,7 @@ func New(opts Options) (*Client, error) {
DisableClientRoutes: &opts.DisableClientRoutes,
DisableIPv6: &opts.DisableIPv6,
BlockInbound: &opts.BlockInbound,
BlockLANAccess: &opts.BlockLANAccess,
WireguardPort: opts.WireguardPort,
MTU: opts.MTU,
DNSLabels: parsedLabels,
@@ -192,6 +220,13 @@ func New(opts Options) (*Client, error) {
config.PrivateKey = opts.PrivateKey
}
if opts.Performance.PreallocatedBuffersPerPool != nil {
wgdevice.SetPreallocatedBuffersPerPool(*opts.Performance.PreallocatedBuffersPerPool)
}
if opts.Performance.MaxBatchSize != nil {
wgdevice.SetMaxBatchSizeOverride(*opts.Performance.MaxBatchSize)
}
return &Client{
deviceName: opts.DeviceName,
setupKey: opts.SetupKey,
@@ -405,6 +440,21 @@ func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession,
}, nil
}
// IdentityForIP looks up a remote peer by its tunnel IP using the
// embedded client's status recorder. Returns the peer's WireGuard public
// key and FQDN. ok=false means the IP isn't in this client's peer
// roster — callers should treat that as "unknown peer".
func (c *Client) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
if !ip.IsValid() || c.recorder == nil {
return "", "", false
}
state, found := c.recorder.PeerStateByIP(ip.String())
if !found {
return "", "", false
}
return state.PubKey, state.FQDN, true
}
// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()
@@ -473,6 +523,25 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
}
// SetPerformance retunes a running Client. Only PreallocatedBuffersPerPool
// takes effect, and only when it was nonzero at construction;
// MaxBatchSize is construction-only and returns an error if set here.
//
// Returns ErrClientNotStarted / ErrEngineNotStarted if the Client is not
// running yet.
func (c *Client) SetPerformance(t Performance) error {
if t.MaxBatchSize != nil {
return errors.New("MaxBatchSize is construction-only and cannot be changed at runtime")
}
engine, err := c.getEngine()
if err != nil {
return err
}
return engine.SetPerformance(internal.Performance{
PreallocatedBuffersPerPool: t.PreallocatedBuffersPerPool,
})
}
// StartCapture begins capturing packets on this client's tunnel device.
// Only one capture can be active at a time; starting a new one stops the previous.
// Call StopCapture (or CaptureSession.Stop) to end it.

View File

@@ -260,23 +260,15 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
; Drop Run, App Paths and Uninstall entries left in the 32-bit registry view
; or HKCU by legacy installers.
DetailPrint "Cleaning legacy 32-bit / HKCU entries..."
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
SetRegView 32
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DeleteRegKey HKLM "${REG_APP_PATH}"
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
DeleteRegKey HKLM "${UNINSTALL_PATH}"
SetRegView 64
; Create autostart registry entry based on checkbox
DetailPrint "Autostart enabled: $AutostartEnabled"
${If} $AutostartEnabled == "1"
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
${Else}
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DetailPrint "Autostart not enabled by user"
${EndIf}
@@ -307,16 +299,11 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
DetailPrint "Terminating Netbird UI process..."
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart entries from every view a previous installer may have used.
; Remove autostart registry entry
DetailPrint "Removing autostart registry entry if exists..."
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
SetRegView 32
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DeleteRegKey HKLM "${REG_APP_PATH}"
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
DeleteRegKey HKLM "${UNINSTALL_PATH}"
SetRegView 64
; Handle data deletion based on checkbox
DetailPrint "Checking if user requested data deletion..."

View File

@@ -339,8 +339,7 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
case entry.Pattern == ".":
return true
case entry.IsWildcard:
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
return strings.HasSuffix(qname, "."+entry.Pattern)
default:
// For non-wildcard patterns:
// If handler wants subdomain matching, allow suffix match

View File

@@ -164,6 +164,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
matchSubdomains: true,
shouldMatch: true,
},
{
name: "wildcard label-boundary mismatch (suffix overlap)",
handlerDomain: "*.b.test.",
queryDomain: "x.ab.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "wildcard label-boundary match",
handlerDomain: "*.b.test.",
queryDomain: "x.b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "wildcard multi-label match",
handlerDomain: "*.b.test.",
queryDomain: "x.y.b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "wildcard no match on multi-label apex",
handlerDomain: "*.b.test.",
queryDomain: "b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "wildcard no match on unrelated suffix containment",
handlerDomain: "*.example.com.",
queryDomain: "notexample.com.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "wildcard accepts pattern registered without trailing dot",
handlerDomain: "*.b.test",
queryDomain: "x.b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
}
for _, tt := range tests {
@@ -273,6 +321,19 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
expectedCalls: 1,
expectedHandler: 2, // highest priority matching handler should be called
},
{
name: "overlapping wildcard suffixes route to correct handler",
handlers: []struct {
pattern string
priority int
}{
{pattern: "*.b.test.", priority: nbdns.PriorityDNSRoute},
{pattern: "*.ab.test.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "app.ab.test.",
expectedCalls: 1,
expectedHandler: 1,
},
{
name: "root zone with specific domain",
handlers: []struct {

View File

@@ -26,6 +26,19 @@ type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
// PeerConnectivity reports whether a tunnel IP belongs to a peer the
// client knows about and whether that peer is currently connected. The
// local resolver uses this to suppress A/AAAA answers whose RDATA points
// at a disconnected peer (typical case: a synthesized private-service
// record pointing at an embedded proxy peer that just went offline).
//
// known=false means the IP isn't in the local peerstore at all — the
// record is left alone (it points at something outside our mesh, e.g.
// a non-peer upstream).
type PeerConnectivity interface {
IsConnectedByIP(ip string) (known, connected bool)
}
type Resolver struct {
mu sync.RWMutex
records map[dns.Question][]dns.RR
@@ -33,6 +46,11 @@ type Resolver struct {
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
zones map[domain.Domain]bool
resolver resolver
// peerConn, when non-nil, is consulted on every A/AAAA answer to
// drop records pointing at disconnected peers. nil disables the
// filter and preserves the legacy "return whatever is registered"
// behaviour for callers that never wire a status source.
peerConn PeerConnectivity
ctx context.Context
cancel context.CancelFunc
@@ -49,6 +67,15 @@ func NewResolver() *Resolver {
}
}
// SetPeerConnectivity wires the per-IP connectivity check used to filter
// out A/AAAA answers pointing at disconnected peers. Pass nil to disable.
// Safe to call multiple times; the latest value wins.
func (d *Resolver) SetPeerConnectivity(p PeerConnectivity) {
d.mu.Lock()
defer d.mu.Unlock()
d.peerConn = p
}
func (d *Resolver) MatchSubdomains() bool {
return true
}
@@ -95,6 +122,7 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
replyMessage.RecursionAvailable = true
result := d.lookupRecords(logger, question)
result.records = d.filterDisconnectedPeerAnswers(logger, question, result.records)
replyMessage.Authoritative = !result.hasExternalData
replyMessage.Answer = result.records
replyMessage.Rcode = d.determineRcode(question, result)
@@ -436,6 +464,78 @@ func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16,
}
}
// filterDisconnectedPeerAnswers drops A/AAAA records whose RDATA matches
// a known but disconnected peer. The synthesized private-service zones
// emit one A record per connected proxy peer in a cluster; when a peer
// goes offline, the server-side refresh removes the record from the
// next netmap, but the client may still hold the previous netmap for a
// short window. This filter is the local belt to that braces — even on
// the stale netmap, the resolver hides the offline target.
//
// Records pointing at unknown IPs (outside the local peerstore, e.g.
// non-mesh upstreams) are never dropped. Non-A/AAAA records pass
// through untouched.
//
// Escape hatch: if filtering would leave the answer empty AND at least
// one record was filtered, the original list is returned. Better to
// hand the client a record that may not respond than NXDOMAIN it
// completely when every proxy peer is offline (the upstream may still
// be reachable some other way, or the peerstore may be stale).
func (d *Resolver) filterDisconnectedPeerAnswers(logger *log.Entry, question dns.Question, records []dns.RR) []dns.RR {
if len(records) == 0 {
return records
}
d.mu.RLock()
checker := d.peerConn
d.mu.RUnlock()
if checker == nil {
return records
}
kept := make([]dns.RR, 0, len(records))
var dropped int
for _, rr := range records {
ip := extractRecordIP(rr)
if ip == "" {
kept = append(kept, rr)
continue
}
known, connected := checker.IsConnectedByIP(ip)
if known && !connected {
dropped++
continue
}
kept = append(kept, rr)
}
if dropped == 0 {
return records
}
if len(kept) == 0 {
logger.Debugf("all %d answers for %s point at disconnected peers; returning the original list", dropped, question.Name)
return records
}
logger.Tracef("dropped %d disconnected-peer answer(s) for %s, returning %d", dropped, question.Name, len(kept))
return kept
}
// extractRecordIP returns the dotted-decimal / colon-hex IP carried by
// an A or AAAA record, or "" for any other record type.
func extractRecordIP(rr dns.RR) string {
switch r := rr.(type) {
case *dns.A:
if r.A == nil {
return ""
}
return r.A.String()
case *dns.AAAA:
if r.AAAA == nil {
return ""
}
return r.AAAA.String()
}
return ""
}
// Update replaces all zones and their records
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
d.mu.Lock()

View File

@@ -30,6 +30,21 @@ func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([
return nil, nil
}
// mockPeerConnectivity returns canned (known, connected) results per IP.
// Used by the disconnected-peer filter tests below. IPs not in the map
// are reported as unknown so the filter leaves them alone.
type mockPeerConnectivity struct {
byIP map[string]struct{ known, connected bool }
}
func (m mockPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) {
v, ok := m.byIP[ip]
if !ok {
return false, false
}
return v.known, v.connected
}
func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.",
@@ -2652,3 +2667,114 @@ func BenchmarkIsInManagedZone_ManyZones(b *testing.B) {
resolver.isInManagedZone(qname)
}
}
// TestLocalResolver_FilterDisconnectedPeerAnswers verifies the
// connectivity-aware filtering layered on top of lookupRecords:
// when an A record's IP belongs to a known peer that's disconnected,
// the record is dropped from the answer. Records for unknown IPs pass
// through. If filtering would empty the answer entirely and at least
// one record was dropped, the original list is restored (escape hatch
// for the "all proxies offline" case).
func TestLocalResolver_FilterDisconnectedPeerAnswers(t *testing.T) {
zone := "svc.cluster.netbird."
connectedRec := nbdns.SimpleRecord{
Name: zone,
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 5,
RData: "100.64.0.10",
}
disconnectedRec := nbdns.SimpleRecord{
Name: zone,
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 5,
RData: "100.64.0.11",
}
unknownRec := nbdns.SimpleRecord{
Name: zone,
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 5,
RData: "203.0.113.5",
}
type ipState struct{ known, connected bool }
tests := []struct {
name string
records []nbdns.SimpleRecord
connByIP map[string]ipState
wantInOrder []string
}{
{
name: "drops disconnected peer, keeps connected",
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
connByIP: map[string]ipState{
"100.64.0.10": {known: true, connected: true},
"100.64.0.11": {known: true, connected: false},
},
wantInOrder: []string{"100.64.0.10"},
},
{
name: "unknown IPs pass through untouched",
records: []nbdns.SimpleRecord{unknownRec, disconnectedRec},
connByIP: map[string]ipState{
"100.64.0.11": {known: true, connected: false},
},
wantInOrder: []string{"203.0.113.5"},
},
{
name: "all disconnected falls back to original list",
records: []nbdns.SimpleRecord{disconnectedRec, connectedRec},
connByIP: map[string]ipState{
"100.64.0.10": {known: true, connected: false},
"100.64.0.11": {known: true, connected: false},
},
wantInOrder: []string{"100.64.0.11", "100.64.0.10"},
},
{
name: "no checker wired returns all records",
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
connByIP: nil,
wantInOrder: []string{"100.64.0.10", "100.64.0.11"},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
resolver := NewResolver()
if tc.connByIP != nil {
cm := mockPeerConnectivity{byIP: make(map[string]struct{ known, connected bool }, len(tc.connByIP))}
for ip, st := range tc.connByIP {
cm.byIP[ip] = struct{ known, connected bool }{st.known, st.connected}
}
resolver.SetPeerConnectivity(cm)
}
resolver.Update([]nbdns.CustomZone{{
Domain: strings.TrimSuffix(zone, "."),
Records: tc.records,
NonAuthoritative: true,
}})
var got *dns.Msg
writer := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
got = m
return nil
},
}
req := new(dns.Msg).SetQuestion(zone, dns.TypeA)
resolver.ServeDNS(writer, req)
require.NotNil(t, got, "resolver must produce a response")
require.Len(t, got.Answer, len(tc.wantInOrder),
"answer count must match expected: %v", tc.wantInOrder)
for i, want := range tc.wantInOrder {
a, ok := got.Answer[i].(*dns.A)
require.True(t, ok, "answer[%d] must be an A record", i)
assert.Equal(t, want, a.A.String(),
"answer[%d] expected %s got %s", i, want, a.A.String())
}
})
}
}

View File

@@ -301,6 +301,11 @@ func newDefaultServer(
warningDelayBase: defaultWarningDelayBase,
healthRefresh: make(chan struct{}, 1),
}
// Wire the local resolver against the peer status recorder so it can
// suppress A/AAAA answers that point at disconnected peers (typical
// case: synthesised private-service records pointing at an embedded
// proxy peer that just went offline).
defaultServer.localResolver.SetPeerConnectivity(localPeerConnectivity{statusRecorder})
// register with root zone, handler chain takes care of the routing
dnsService.RegisterMux(".", handlerChain)
@@ -1386,3 +1391,25 @@ func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error {
}
return nil
}
// localPeerConnectivity adapts *peer.Status to local.PeerConnectivity so
// the local resolver can ask "is this IP a known peer and is it
// connected?" without taking on the peer package as a dependency.
// A nil status recorder always reports known=false so the resolver
// short-circuits to the legacy "return everything" path.
type localPeerConnectivity struct {
status *peer.Status
}
// IsConnectedByIP looks the IP up in the peerstore and surfaces both
// the known and connected bits. Used by Resolver.filterDisconnectedPeerAnswers.
func (l localPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) {
if l.status == nil {
return false, false
}
state, ok := l.status.PeerStateByIP(ip)
if !ok {
return false, false
}
return true, state.ConnStatus == peer.StatusConnected
}

View File

@@ -1967,6 +1967,29 @@ func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
return e.clientMetrics
}
// Performance bundles runtime-adjustable tunnel pool knobs.
// See Engine.SetPerformance. Nil fields are ignored.
type Performance struct {
PreallocatedBuffersPerPool *uint32
}
// SetPerformance applies the given tuning to this engine's live Device.
func (e *Engine) SetPerformance(t Performance) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if e.wgInterface == nil {
return fmt.Errorf("wg interface not initialized")
}
dev := e.wgInterface.GetWGDevice()
if dev == nil {
return fmt.Errorf("wg device not initialized")
}
if t.PreallocatedBuffersPerPool != nil {
dev.SetPreallocatedBuffersPerPool(*t.PreallocatedBuffersPerPool)
}
return nil
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {

View File

@@ -27,7 +27,7 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
@@ -66,8 +66,8 @@ import (
"github.com/netbirdio/netbird/route"
mgmt "github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/shared/netiputil"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
@@ -1641,7 +1641,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)

View File

@@ -50,7 +50,7 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
switch msg.Type {
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
route, flags, err := parseRouteMessage(buf[:n])
if err != nil {
log.Debugf("Network monitor: error parsing routing message: %v", err)
continue
@@ -66,6 +66,10 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
}
switch msg.Type {
case unix.RTM_ADD:
if systemops.IgnoreAddedDefaultRoute(flags) {
log.Debugf("Network monitor: ignoring added default route via %s, interface %s, flags %#x", route.Gw, intf, flags)
continue
}
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
return nil
case unix.RTM_DELETE:
@@ -78,22 +82,26 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
}
}
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
func parseRouteMessage(buf []byte) (*systemops.Route, int, error) {
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err)
return nil, 0, fmt.Errorf("parse RIB: %v", err)
}
if len(msgs) != 1 {
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
return nil, 0, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
}
msg, ok := msgs[0].(*route.RouteMessage)
if !ok {
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
return nil, 0, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
}
return systemops.MsgToRoute(msg)
r, err := systemops.MsgToRoute(msg)
if err != nil {
return nil, 0, err
}
return r, msg.Flags, nil
}
// waitReadable blocks until fd has data to read, or ctx is cancelled.

View File

@@ -0,0 +1,46 @@
package owner
import (
"fmt"
"sync"
)
// ConfigAdapter is a thread-safe OwnerConfig that delegates to a lazily-set backend.
// This allows the interceptor to be created before the daemon server (and its config)
// is initialized, which is necessary because gRPC interceptors are set at server creation time.
type ConfigAdapter struct {
mu sync.RWMutex
backend OwnerConfig
}
// SetBackend sets the actual config implementation. Must be called before any RPCs are served.
func (a *ConfigAdapter) SetBackend(backend OwnerConfig) {
a.mu.Lock()
defer a.mu.Unlock()
a.backend = backend
}
// GetOwnerUIDs delegates to the backend.
func (a *ConfigAdapter) GetOwnerUIDs() []UID {
a.mu.RLock()
defer a.mu.RUnlock()
if a.backend == nil {
// No backend yet, return empty (root-only).
return []UID{}
}
return a.backend.GetOwnerUIDs()
}
// AddOwnerUID delegates to the backend.
func (a *ConfigAdapter) AddOwnerUID(uid UID) error {
a.mu.RLock()
defer a.mu.RUnlock()
if a.backend == nil {
return fmt.Errorf("owner config backend not initialized")
}
return a.backend.AddOwnerUID(uid)
}

View File

@@ -0,0 +1,17 @@
// Package consoleuser provides the OS-level "active console user" UID lookup
// used to gate ownership TOFU. The active console user is the local user
// physically at the machine (or in the foreground GUI session): the user that
// can legitimately claim the daemon as theirs on first run.
package consoleuser
// ActiveUID returns the UID of the currently active console / GUI session
// user, and true if such a user exists. Returns 0, false on platforms without
// a console concept (ios, android), on headless servers with no active
// session, or on lookup failure.
//
// Implementations must fail closed: any error or ambiguity returns (0, false)
// so that the caller treats the result as "no console user" rather than
// granting access to an unverified UID.
func ActiveUID() (uint32, bool) {
return activeUID()
}

View File

@@ -0,0 +1,58 @@
package consoleuser
import (
"unsafe"
"github.com/ebitengine/purego"
)
// activeUID returns the UID of the user currently logged into the macOS GUI
// console session. Uses SCDynamicStoreCopyConsoleUser from the
// SystemConfiguration framework via purego (no cgo).
func activeUID() (uint32, bool) {
sc, err := purego.Dlopen(
"/System/Library/Frameworks/SystemConfiguration.framework/SystemConfiguration",
purego.RTLD_NOW|purego.RTLD_GLOBAL,
)
if err != nil {
return 0, false
}
cf, err := purego.Dlopen(
"/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation",
purego.RTLD_NOW|purego.RTLD_GLOBAL,
)
if err != nil {
return 0, false
}
// CFStringRef SCDynamicStoreCopyConsoleUser(SCDynamicStoreRef store,
// uid_t *uid, gid_t *gid);
//
// We pass nil for the store (NULL is accepted; the framework creates a
// transient one), discard the returned CFStringRef username (we only
// need the UID), and read uid via the out-pointer.
var copyConsoleUser func(store uintptr, uidPtr, gidPtr unsafe.Pointer) uintptr
purego.RegisterLibFunc(&copyConsoleUser, sc, "SCDynamicStoreCopyConsoleUser")
var cfRelease func(uintptr)
purego.RegisterLibFunc(&cfRelease, cf, "CFRelease")
var uid uint32
var gid uint32
cfStr := copyConsoleUser(0, unsafe.Pointer(&uid), unsafe.Pointer(&gid))
if cfStr == 0 {
return 0, false
}
cfRelease(cfStr)
// loginwindow / no GUI session reports uid 0. We don't want the
// console-user path to grant anything to root (root is already always
// allowed by the interceptor), so treat uid 0 as "no console user".
if uid == 0 {
return 0, false
}
return uid, true
}

View File

@@ -0,0 +1,34 @@
package consoleuser
import (
"fmt"
"os"
"syscall"
)
// activeUID returns the UID of the user currently logged into the FreeBSD
// console. FreeBSD's vt(4) chowns the active virtual terminal device to the
// logged-in user, so a non-root owner of any /dev/ttyvN reliably identifies
// the console user.
//
// We scan /dev/ttyv0../dev/ttyv9 and return the first non-root owner. Network
// ptys (pts) are intentionally not considered: SSH'd users are not "at the
// console" and must not TOFU-claim ownership.
func activeUID() (uint32, bool) {
for i := 0; i < 10; i++ {
path := fmt.Sprintf("/dev/ttyv%d", i)
fi, err := os.Stat(path)
if err != nil {
continue
}
st, ok := fi.Sys().(*syscall.Stat_t)
if !ok {
continue
}
if st.Uid == 0 {
continue
}
return st.Uid, true
}
return 0, false
}

View File

@@ -0,0 +1,64 @@
package consoleuser
import (
"github.com/godbus/dbus/v5"
)
const (
loginDest = "org.freedesktop.login1"
loginPath = dbus.ObjectPath("/org/freedesktop/login1")
loginInterface = "org.freedesktop.login1.Manager"
listSessions = loginInterface + ".ListSessions"
sessionInterface = "org.freedesktop.login1.Session"
sessionActive = sessionInterface + ".Active"
sessionClass = sessionInterface + ".Class"
)
// activeUID queries systemd-logind for the active local user session and
// returns that user's UID. Falls back to (0, false) on any error or when no
// active user session exists (headless box, no GUI, no login at the console).
func activeUID() (uint32, bool) {
conn, err := dbus.SystemBus()
if err != nil {
return 0, false
}
mgr := conn.Object(loginDest, loginPath)
// ListSessions returns []struct{ID string; UID uint32; User string;
// Seat string; Path dbus.ObjectPath}.
var sessions []struct {
ID string
UID uint32
User string
Seat string
Path dbus.ObjectPath
}
if err := mgr.Call(listSessions, 0).Store(&sessions); err != nil {
return 0, false
}
for _, s := range sessions {
obj := conn.Object(loginDest, s.Path)
active, err := obj.GetProperty(sessionActive)
if err != nil || active.Value() != true {
continue
}
class, err := obj.GetProperty(sessionClass)
if err != nil {
continue
}
// Only "user" sessions count; "greeter" / "lock-screen" / etc. are
// not someone we should grant ownership to.
if classStr, ok := class.Value().(string); !ok || classStr != "user" {
continue
}
return s.UID, true
}
return 0, false
}

View File

@@ -0,0 +1,9 @@
//go:build !linux && !darwin && !freebsd && !windows
package consoleuser
// activeUID has no meaning on platforms without a console-user concept
// (ios, android). Returns no-user so TOFU never fires.
func activeUID() (uint32, bool) {
return 0, false
}

View File

@@ -0,0 +1,59 @@
package consoleuser
import (
"unsafe"
"golang.org/x/sys/windows"
)
// activeUID returns a synthetic UID (the user SID's RID) for the currently
// active Windows console session. The owner package treats UIDs as opaque
// uint32 identifiers; on Windows we use the user account RID, which is stable
// per-account on a given machine.
//
// Returns (0, false) when there is no active console session, the session has
// no logged-in user, or any lookup fails.
func activeUID() (uint32, bool) {
sessionID := windows.WTSGetActiveConsoleSessionId()
if sessionID == 0xFFFFFFFF {
return 0, false
}
var token windows.Token
if err := windows.WTSQueryUserToken(sessionID, &token); err != nil {
return 0, false
}
defer token.Close()
user, err := tokenUserSID(token)
if err != nil || user == nil {
return 0, false
}
subCount := user.SubAuthorityCount()
if subCount == 0 {
return 0, false
}
rid := user.SubAuthority(uint32(subCount) - 1)
if rid == 0 {
return 0, false
}
return rid, true
}
// tokenUserSID returns the user SID associated with the given access token.
func tokenUserSID(token windows.Token) (*windows.SID, error) {
var size uint32
err := windows.GetTokenInformation(token, windows.TokenUser, nil, 0, &size)
if err != windows.ERROR_INSUFFICIENT_BUFFER {
return nil, err
}
buf := make([]byte, size)
if err := windows.GetTokenInformation(token, windows.TokenUser, &buf[0], size, &size); err != nil {
return nil, err
}
tu := (*windows.Tokenuser)(unsafe.Pointer(&buf[0]))
return tu.User.Sid, nil
}

View File

@@ -0,0 +1,37 @@
package owner
import (
"context"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/peer"
)
// UnixAuthInfo implements credentials.AuthInfo carrying the peer's UID from SO_PEERCRED.
type UnixAuthInfo struct {
credentials.CommonAuthInfo
UID UID
GID uint32
PID int32
}
// AuthType returns the authentication type.
func (u UnixAuthInfo) AuthType() string {
return "unix_peercred"
}
// UIDFromContext extracts the caller's UID from the gRPC peer context.
// Returns uid and true if Unix credentials were available, 0 and false otherwise.
func UIDFromContext(ctx context.Context) (UID, bool) {
p, ok := peer.FromContext(ctx)
if !ok {
return 0, false
}
info, ok := p.AuthInfo.(UnixAuthInfo)
if !ok {
return 0, false
}
return info.UID, true
}

View File

@@ -0,0 +1,48 @@
package owner
import (
"os"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
)
// EnvOwnerUID is the environment variable that seeds the owner UID list for new config files.
// MDM deployments can set this (e.g. via --service-env NB_OWNER_UID=1000) so the first
// config created by the daemon pre-populates the owner without requiring "netbird up --owner".
// Multiple UIDs can be comma-separated: NB_OWNER_UID=1000,1001
const EnvOwnerUID = "NB_OWNER_UID"
// OwnerUIDsFromEnv parses NB_OWNER_UID into a UID slice.
// Returns nil if the variable is unset, allowing the caller to distinguish
// "not configured" from "explicitly empty".
func OwnerUIDsFromEnv() []UID {
val := os.Getenv(EnvOwnerUID)
if val == "" {
return nil
}
parts := strings.Split(val, ",")
uids := make([]UID, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p == "" {
continue
}
uid, err := strconv.ParseUint(p, 10, 32)
if err != nil {
log.Warnf("ignoring invalid UID %q in %s: %v", p, EnvOwnerUID, err)
continue
}
uids = append(uids, UID(uid))
}
if len(uids) == 0 {
log.Warnf("%s set but contains no valid UIDs, defaulting to root-only", EnvOwnerUID)
return []UID{}
}
log.Infof("seeding owner UIDs from %s: %v", EnvOwnerUID, uids)
return uids
}

View File

@@ -0,0 +1,81 @@
package owner
import (
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOwnerUIDsFromEnv(t *testing.T) {
tests := []struct {
name string
envValue string
unset bool
want []UID
}{
{
name: "unset returns nil",
unset: true,
want: nil,
},
{
name: "empty string returns nil",
envValue: "",
want: nil,
},
{
name: "single UID",
envValue: "1000",
want: []UID{1000},
},
{
name: "multiple UIDs",
envValue: "1000,1001,1002",
want: []UID{1000, 1001, 1002},
},
{
name: "spaces around UIDs",
envValue: " 1000 , 1001 ",
want: []UID{1000, 1001},
},
{
name: "invalid UID skipped",
envValue: "1000,notanumber,1001",
want: []UID{1000, 1001},
},
{
name: "all invalid returns empty slice",
envValue: "abc,def",
want: []UID{},
},
{
name: "trailing comma",
envValue: "1000,",
want: []UID{1000},
},
{
name: "zero UID is valid",
envValue: "0",
want: []UID{0},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv(EnvOwnerUID, tt.envValue)
if tt.unset {
os.Unsetenv(EnvOwnerUID)
}
got := OwnerUIDsFromEnv()
if tt.want == nil {
require.Nil(t, got)
} else {
assert.Equal(t, tt.want, got)
}
})
}
}

View File

@@ -0,0 +1,170 @@
package owner
import (
"context"
"slices"
"sync"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/owner/consoleuser"
)
const servicePath = "/daemon.DaemonService/"
// profileBypassMethods skip the active-profile owner check. They either
// operate on a specific target profile (and the handler enforces target-profile
// owner-or-root itself) or are per-user listings/creations that don't affect
// the active session and shouldn't require active-profile ownership. Peer
// credentials are still required.
var profileBypassMethods = map[string]bool{
servicePath + "AddProfile": true,
servicePath + "ListProfiles": true,
servicePath + "RemoveProfile": true,
servicePath + "SwitchProfile": true,
}
// Error messages returned to denied callers. They are multi-line so the
// suggested commands sit on their own line for easy triple-click copy-paste.
const (
errNoPeerCreds = "peer credentials unavailable; rerun via the netbird CLI"
errNoOwnerConfigured = `no daemon owner is configured and no console-session user matches your UID.
Run as root for one-off use:
sudo netbird ...
Or call from the active console session: the first call from the user logged in
at the GUI/console claims ownership automatically.`
errOwnerRequired = `this operation requires root or the daemon owner (uid %d is not an owner).
Run as root for one-off use:
sudo netbird ...
Or ask an existing owner (or root) to add you:
sudo netbird owner add %[1]d`
)
// consoleUIDLookup is the function used to look up the active console UID.
// Overridable in tests; defaults to the platform implementation.
var consoleUIDLookup = consoleuser.ActiveUID
// OwnerConfig provides access to the current owner UIDs setting.
// The interceptor reads and writes through this interface so it can
// work with the profile manager's config without a direct dependency.
type OwnerConfig interface {
// GetOwnerUIDs returns the current owner UIDs.
// nil means legacy/migration TOFU (field absent from existing config).
// empty means fresh install (root-only with console-user TOFU exception).
// populated means those UIDs plus root may control the daemon.
GetOwnerUIDs() []UID
// AddOwnerUID adds the given UID to the owner list and persists it.
AddOwnerUID(uid UID) error
}
// Interceptor enforces owner restrictions on the daemon gRPC socket.
type Interceptor struct {
config OwnerConfig
// mu serializes the read-then-write of OwnerUIDs during TOFU/claim flows
// so two concurrent first-callers can't both end up persisted as owners.
// Holds across the OwnerConfig.AddOwnerUID call; safe because no callback
// path takes this mutex.
mu sync.Mutex
}
// NewInterceptor creates an owner interceptor backed by the given config.
func NewInterceptor(config OwnerConfig) *Interceptor {
return &Interceptor{config: config}
}
// UnaryInterceptor returns a gRPC unary server interceptor that enforces owner policy.
func (i *Interceptor) UnaryInterceptor() grpc.UnaryServerInterceptor {
return func(
ctx context.Context,
req any,
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (any, error) {
if err := i.authorize(ctx, info.FullMethod); err != nil {
return nil, err
}
return handler(ctx, req)
}
}
// StreamInterceptor returns a gRPC stream server interceptor that enforces owner policy.
func (i *Interceptor) StreamInterceptor() grpc.StreamServerInterceptor {
return func(
srv any,
ss grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
if err := i.authorize(ss.Context(), info.FullMethod); err != nil {
return err
}
return handler(srv, ss)
}
}
// authorize checks whether the caller is allowed to call the given method.
// Every RPC is gated; root is always allowed. Non-root callers are accepted
// when they are existing owners, when the config is in legacy TOFU state
// (claim on first call, preserves pre-enforcement behavior), or when the
// config is in fresh-install state and they match the active console user.
func (i *Interceptor) authorize(ctx context.Context, fullMethod string) error {
uid, ok := UIDFromContext(ctx)
if !ok {
return status.Error(codes.PermissionDenied, errNoPeerCreds)
}
if uid == 0 {
return nil
}
// Profile-management RPCs do their own per-target authorization in the
// handler. The interceptor only confirms peer credentials are present.
if profileBypassMethods[fullMethod] {
return nil
}
i.mu.Lock()
defer i.mu.Unlock()
ownerUIDs := i.config.GetOwnerUIDs()
switch {
case ownerUIDs == nil:
// Legacy / migration TOFU: existing pre-enforcement config has no
// owners field. Any non-root local caller claims on first call so
// upgrades don't break.
return i.claim(uid, "migration TOFU")
case len(ownerUIDs) == 0:
// Fresh-install root-only mode with a console-user exception so the
// GUI/CLI just works for the user physically at the machine. SSH'd
// or otherwise non-console callers are denied.
consoleUID, ok := consoleUIDLookup()
if ok && uint32(uid) == consoleUID {
return i.claim(uid, "console-user TOFU")
}
return status.Error(codes.PermissionDenied, errNoOwnerConfigured)
case slices.Contains(ownerUIDs, uid):
return nil
default:
return status.Errorf(codes.PermissionDenied, errOwnerRequired, uid)
}
}
// claim adds uid to the owner list and persists it. The caller must hold i.mu.
func (i *Interceptor) claim(uid UID, reason string) error {
log.Infof("%s: claiming owner for UID %d", reason, uid)
if err := i.config.AddOwnerUID(uid); err != nil {
log.Errorf("persist owner UID: %v", err)
return status.Error(codes.Internal, "persist owner UID")
}
return nil
}

View File

@@ -0,0 +1,277 @@
package owner
import (
"context"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
)
type mockOwnerConfig struct {
uids []UID
err error
}
func (m *mockOwnerConfig) GetOwnerUIDs() []UID {
return m.uids
}
func (m *mockOwnerConfig) AddOwnerUID(uid UID) error {
if m.err != nil {
return m.err
}
m.uids = append(m.uids, uid)
return nil
}
func peerContext(uid UID) context.Context {
return peer.NewContext(context.Background(), &peer.Peer{
Addr: &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"},
AuthInfo: UnixAuthInfo{
CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity},
UID: uid,
},
})
}
func noPeerContext() context.Context {
return context.Background()
}
// withConsoleUID overrides the platform console-user lookup for a single test.
func withConsoleUID(t *testing.T, uid uint32, ok bool) {
t.Helper()
prev := consoleUIDLookup
consoleUIDLookup = func() (uint32, bool) { return uid, ok }
t.Cleanup(func() { consoleUIDLookup = prev })
}
func TestInterceptor_RootAlwaysAllowed(t *testing.T) {
cfg := &mockOwnerConfig{uids: []UID{1000}}
interceptor := NewInterceptor(cfg)
for _, method := range []string{
"/daemon.DaemonService/Up",
"/daemon.DaemonService/Status",
"/daemon.DaemonService/Down",
} {
err := interceptor.authorize(peerContext(0), method)
assert.NoError(t, err, "root should always be allowed for %s", method)
}
}
func TestInterceptor_NoPeerCreds_AlwaysDenies(t *testing.T) {
cfg := &mockOwnerConfig{uids: []UID{1000}}
interceptor := NewInterceptor(cfg)
for _, method := range []string{
"/daemon.DaemonService/Status",
"/daemon.DaemonService/Up",
"/daemon.DaemonService/SomeNewMethod",
} {
err := interceptor.authorize(noPeerContext(), method)
require.Error(t, err, "method %s should be denied without peer creds", method)
assert.Equal(t, codes.PermissionDenied, status.Code(err))
}
}
// TestInterceptor_LegacyMigration covers the nil-OwnerUIDs branch:
// pre-enforcement configs upgraded to this version. Any non-root local caller
// can claim on first call.
func TestInterceptor_LegacyMigration_AnyCallerClaims(t *testing.T) {
withConsoleUID(t, 0, false) // no console; should not matter for nil
cfg := &mockOwnerConfig{uids: nil}
interceptor := NewInterceptor(cfg)
// First call from any UID claims regardless of method.
err := interceptor.authorize(peerContext(1000), "/daemon.DaemonService/Status")
require.NoError(t, err)
require.Equal(t, []UID{1000}, cfg.uids)
// After claim, a different UID is denied.
err = interceptor.authorize(peerContext(2000), "/daemon.DaemonService/Status")
require.Error(t, err)
assert.Equal(t, codes.PermissionDenied, status.Code(err))
}
// TestInterceptor_FreshInstall covers the empty-OwnerUIDs branch: console-user
// can claim, others denied.
func TestInterceptor_FreshInstall_ConsoleUserClaims(t *testing.T) {
withConsoleUID(t, 1000, true)
cfg := &mockOwnerConfig{uids: []UID{}}
interceptor := NewInterceptor(cfg)
err := interceptor.authorize(peerContext(1000), "/daemon.DaemonService/Status")
require.NoError(t, err)
require.Equal(t, []UID{1000}, cfg.uids)
}
func TestInterceptor_FreshInstall_NonConsoleDenied(t *testing.T) {
withConsoleUID(t, 1000, true)
cfg := &mockOwnerConfig{uids: []UID{}}
interceptor := NewInterceptor(cfg)
err := interceptor.authorize(peerContext(2000), "/daemon.DaemonService/Up")
require.Error(t, err)
assert.Equal(t, codes.PermissionDenied, status.Code(err))
assert.Empty(t, cfg.uids, "non-console caller must not claim")
}
func TestInterceptor_FreshInstall_NoConsole_Denied(t *testing.T) {
withConsoleUID(t, 0, false)
cfg := &mockOwnerConfig{uids: []UID{}}
interceptor := NewInterceptor(cfg)
err := interceptor.authorize(peerContext(1000), "/daemon.DaemonService/Up")
require.Error(t, err)
assert.Equal(t, codes.PermissionDenied, status.Code(err))
}
func TestInterceptor_OwnerUID_AllowsOwner(t *testing.T) {
cfg := &mockOwnerConfig{uids: []UID{1000}}
interceptor := NewInterceptor(cfg)
err := interceptor.authorize(peerContext(1000), "/daemon.DaemonService/Down")
assert.NoError(t, err)
}
func TestInterceptor_OwnerUID_DeniesOther(t *testing.T) {
withConsoleUID(t, 9999, true) // console-user TOFU should not apply once owners exist
cfg := &mockOwnerConfig{uids: []UID{1000}}
interceptor := NewInterceptor(cfg)
err := interceptor.authorize(peerContext(2000), "/daemon.DaemonService/Down")
require.Error(t, err)
assert.Equal(t, codes.PermissionDenied, status.Code(err))
}
func TestInterceptor_MultipleOwners(t *testing.T) {
cfg := &mockOwnerConfig{uids: []UID{1000, 2000}}
interceptor := NewInterceptor(cfg)
err := interceptor.authorize(peerContext(1000), "/daemon.DaemonService/Down")
assert.NoError(t, err)
err = interceptor.authorize(peerContext(2000), "/daemon.DaemonService/Up")
assert.NoError(t, err)
err = interceptor.authorize(peerContext(3000), "/daemon.DaemonService/Down")
require.Error(t, err)
assert.Equal(t, codes.PermissionDenied, status.Code(err))
}
// TestInterceptor_UnknownMethodRequiresOwner pins the safe-by-default invariant:
// any future RPC still goes through owner enforcement.
func TestInterceptor_UnknownMethodRequiresOwner(t *testing.T) {
cfg := &mockOwnerConfig{uids: []UID{1000}}
interceptor := NewInterceptor(cfg)
err := interceptor.authorize(peerContext(2000), "/daemon.DaemonService/SomeFutureMethod")
require.Error(t, err)
assert.Equal(t, codes.PermissionDenied, status.Code(err))
err = interceptor.authorize(peerContext(1000), "/daemon.DaemonService/SomeFutureMethod")
assert.NoError(t, err)
}
func TestInterceptor_ErrorMessageActionable(t *testing.T) {
withConsoleUID(t, 9999, true)
cfg := &mockOwnerConfig{uids: []UID{1000}}
interceptor := NewInterceptor(cfg)
err := interceptor.authorize(peerContext(2000), "/daemon.DaemonService/Down")
require.Error(t, err)
msg := status.Convert(err).Message()
assert.Contains(t, msg, "sudo netbird")
assert.Contains(t, msg, "owner add")
}
func TestInterceptor_UnaryIntegration(t *testing.T) {
cfg := &mockOwnerConfig{uids: []UID{1000}}
interceptor := NewInterceptor(cfg)
unary := interceptor.UnaryInterceptor()
resp, err := unary(peerContext(1000), nil, &grpc.UnaryServerInfo{FullMethod: "/daemon.DaemonService/Down"}, func(ctx context.Context, req any) (any, error) {
return "ok", nil
})
require.NoError(t, err)
assert.Equal(t, "ok", resp)
_, err = unary(peerContext(2000), nil, &grpc.UnaryServerInfo{FullMethod: "/daemon.DaemonService/Down"}, func(ctx context.Context, req any) (any, error) {
t.Fatal("handler should not be called")
return nil, nil
})
require.Error(t, err)
assert.Equal(t, codes.PermissionDenied, status.Code(err))
}
func TestInterceptor_StreamIntegration(t *testing.T) {
cfg := &mockOwnerConfig{uids: []UID{1000}}
interceptor := NewInterceptor(cfg)
stream := interceptor.StreamInterceptor()
called := false
err := stream(nil, &mockServerStream{ctx: peerContext(1000)},
&grpc.StreamServerInfo{FullMethod: "/daemon.DaemonService/SubscribeEvents"},
func(srv any, stream grpc.ServerStream) error {
called = true
return nil
})
require.NoError(t, err)
assert.True(t, called)
err = stream(nil, &mockServerStream{ctx: peerContext(2000)},
&grpc.StreamServerInfo{FullMethod: "/daemon.DaemonService/SubscribeEvents"},
func(srv any, stream grpc.ServerStream) error {
t.Fatal("handler should not be called")
return nil
})
require.Error(t, err)
assert.Equal(t, codes.PermissionDenied, status.Code(err))
}
type mockServerStream struct {
grpc.ServerStream
ctx context.Context
}
func (m *mockServerStream) Context() context.Context { return m.ctx }
// TestInterceptor_ProfileBypass pins that profile-management methods reach
// the handler regardless of active-profile ownership; the handler enforces
// per-target-profile auth itself.
func TestInterceptor_ProfileBypass(t *testing.T) {
cfg := &mockOwnerConfig{uids: []UID{1000}}
interceptor := NewInterceptor(cfg)
// Caller UID 2000 is not an owner of the active profile but must be
// allowed through for these methods.
for _, method := range []string{
"/daemon.DaemonService/AddProfile",
"/daemon.DaemonService/ListProfiles",
"/daemon.DaemonService/RemoveProfile",
"/daemon.DaemonService/SwitchProfile",
} {
err := interceptor.authorize(peerContext(2000), method)
assert.NoError(t, err, "profile method %s should bypass active-owner check", method)
}
// Without peer creds, even bypass methods are denied.
for _, method := range []string{
"/daemon.DaemonService/AddProfile",
"/daemon.DaemonService/SwitchProfile",
} {
err := interceptor.authorize(noPeerContext(), method)
require.Error(t, err, "bypass method %s still requires peer creds", method)
assert.Equal(t, codes.PermissionDenied, status.Code(err))
}
}

View File

@@ -0,0 +1,66 @@
//go:build darwin || freebsd
package owner
import (
"context"
"fmt"
"net"
"golang.org/x/sys/unix"
"google.golang.org/grpc/credentials"
)
// NewUnixTransportCredentials returns gRPC TransportCredentials that extract
// peer UID from Unix socket connections via LOCAL_PEERCRED (Xucred).
func NewUnixTransportCredentials() credentials.TransportCredentials {
return &unixCreds{}
}
type unixCreds struct{}
func (c *unixCreds) ClientHandshake(_ context.Context, _ string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return conn, UnixAuthInfo{}, nil
}
// ServerHandshake extracts peer credentials from the Unix connection using LOCAL_PEERCRED.
// Returns an error if credentials cannot be extracted (fail-closed).
func (c *unixCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
uc, ok := conn.(*net.UnixConn)
if !ok {
return nil, nil, fmt.Errorf("expected *net.UnixConn, got %T", conn)
}
raw, err := uc.SyscallConn()
if err != nil {
return nil, nil, fmt.Errorf("get raw conn for peer credentials: %w", err)
}
var xucred *unix.Xucred
var credErr error
if err := raw.Control(func(fd uintptr) {
xucred, credErr = unix.GetsockoptXucred(int(fd), unix.SOL_LOCAL, unix.LOCAL_PEERCRED)
}); err != nil {
return nil, nil, fmt.Errorf("control raw conn for peer credentials: %w", err)
}
if credErr != nil {
return nil, nil, fmt.Errorf("get peer credentials: %w", credErr)
}
return conn, UnixAuthInfo{
CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity},
UID: UID(xucred.Uid),
}, nil
}
func (c *unixCreds) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{SecurityProtocol: "unix_peercred"}
}
func (c *unixCreds) Clone() credentials.TransportCredentials {
return &unixCreds{}
}
func (c *unixCreds) OverrideServerName(_ string) error {
return nil
}

View File

@@ -0,0 +1,11 @@
//go:build !linux && !darwin && !freebsd
package owner
import "google.golang.org/grpc/credentials"
// NewUnixTransportCredentials returns nil on platforms without Unix socket peer credentials.
// The daemon should use insecure credentials and skip owner enforcement.
func NewUnixTransportCredentials() credentials.TransportCredentials {
return nil
}

View File

@@ -0,0 +1,66 @@
package owner
import (
"context"
"fmt"
"net"
"golang.org/x/sys/unix"
"google.golang.org/grpc/credentials"
)
// NewUnixTransportCredentials returns gRPC TransportCredentials that extract
// peer UID/GID/PID from Unix socket connections via SO_PEERCRED.
func NewUnixTransportCredentials() credentials.TransportCredentials {
return &unixCreds{}
}
type unixCreds struct{}
func (c *unixCreds) ClientHandshake(_ context.Context, _ string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return conn, UnixAuthInfo{}, nil
}
// ServerHandshake extracts peer credentials from the Unix connection.
// Returns an error if credentials cannot be extracted (fail-closed).
func (c *unixCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
uc, ok := conn.(*net.UnixConn)
if !ok {
return nil, nil, fmt.Errorf("expected *net.UnixConn, got %T", conn)
}
raw, err := uc.SyscallConn()
if err != nil {
return nil, nil, fmt.Errorf("get raw conn for peer credentials: %w", err)
}
var ucred *unix.Ucred
var credErr error
if err := raw.Control(func(fd uintptr) {
ucred, credErr = unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
}); err != nil {
return nil, nil, fmt.Errorf("control raw conn for peer credentials: %w", err)
}
if credErr != nil {
return nil, nil, fmt.Errorf("get peer credentials: %w", credErr)
}
return conn, UnixAuthInfo{
CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity},
UID: UID(ucred.Uid),
GID: ucred.Gid,
PID: ucred.Pid,
}, nil
}
func (c *unixCreds) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{SecurityProtocol: "unix_peercred"}
}
func (c *unixCreds) Clone() credentials.TransportCredentials {
return &unixCreds{}
}
func (c *unixCreds) OverrideServerName(_ string) error {
return nil
}

View File

@@ -0,0 +1,107 @@
package owner
import (
"net"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/credentials"
)
func TestUnixTransportCredentials_ServerHandshake(t *testing.T) {
creds := NewUnixTransportCredentials()
if creds == nil {
t.Skip("unix transport credentials not supported on this platform")
}
sockPath := filepath.Join(t.TempDir(), "test.sock")
ln, err := net.Listen("unix", sockPath)
require.NoError(t, err)
t.Cleanup(func() { ln.Close() })
done := make(chan struct{})
var serverConn net.Conn
var serverAuth credentials.AuthInfo
var serverErr error
go func() {
defer close(done)
raw, err := ln.Accept()
if err != nil {
serverErr = err
return
}
serverConn, serverAuth, serverErr = creds.ServerHandshake(raw)
}()
client, err := net.Dial("unix", sockPath)
require.NoError(t, err)
t.Cleanup(func() { client.Close() })
<-done
require.NoError(t, serverErr)
require.NotNil(t, serverConn)
t.Cleanup(func() { serverConn.Close() })
authInfo, ok := serverAuth.(UnixAuthInfo)
require.True(t, ok, "expected UnixAuthInfo, got %T", serverAuth)
assert.Equal(t, UID(os.Getuid()), authInfo.UID, "UID should match current user")
}
func TestUnixTransportCredentials_ServerHandshake_NonUnixConn(t *testing.T) {
creds := NewUnixTransportCredentials()
if creds == nil {
t.Skip("unix transport credentials not supported on this platform")
}
// Use a TCP connection, which is not *net.UnixConn.
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(func() { ln.Close() })
done := make(chan struct{})
var handshakeErr error
go func() {
defer close(done)
raw, err := ln.Accept()
if err != nil {
handshakeErr = err
return
}
defer raw.Close()
_, _, handshakeErr = creds.ServerHandshake(raw)
}()
client, err := net.Dial("tcp", ln.Addr().String())
require.NoError(t, err)
t.Cleanup(func() { client.Close() })
<-done
require.Error(t, handshakeErr, "ServerHandshake must fail for non-Unix connections")
}
func TestUnixTransportCredentials_Info(t *testing.T) {
creds := NewUnixTransportCredentials()
if creds == nil {
t.Skip("unix transport credentials not supported on this platform")
}
info := creds.Info()
assert.Equal(t, "unix_peercred", info.SecurityProtocol)
}
func TestUnixTransportCredentials_Clone(t *testing.T) {
creds := NewUnixTransportCredentials()
if creds == nil {
t.Skip("unix transport credentials not supported on this platform")
}
cloned := creds.Clone()
require.NotNil(t, cloned)
assert.Equal(t, creds.Info(), cloned.Info())
}

View File

@@ -0,0 +1,5 @@
package owner
// UID is a Unix user ID. Defined as a distinct type so it can't be silently
// swapped with GID, PID, or other uint32 values at call sites.
type UID uint32

View File

@@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
@@ -899,7 +900,7 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
}
// Fallback to deterministic key if no NetBird PSK is configured
determKey, err := conn.rosenpassDetermKey()
determKey, err := rosenpass.DeterministicSeedKey(conn.config.LocalKey, conn.config.Key)
if err != nil {
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
return nil
@@ -908,26 +909,6 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
return determKey
}
// todo: move this logic into Rosenpass package
func (conn *Conn) rosenpassDetermKey() (*wgtypes.Key, error) {
lk := []byte(conn.config.LocalKey)
rk := []byte(conn.config.Key) // remote key
var keyInput []byte
if string(lk) > string(rk) {
//nolint:gocritic
keyInput = append(lk[:16], rk[:16]...)
} else {
//nolint:gocritic
keyInput = append(rk[:16], lk[:16]...)
}
key, err := wgtypes.NewKey(keyInput)
if err != nil {
return nil, err
}
return &key, nil
}
func isController(config ConnConfig) bool {
return config.LocalKey > config.Key
}

View File

@@ -185,9 +185,12 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
return s.eventsChan
}
// Status holds a state of peers, signal, management connections and relays
// Status holds a state of peers, signal, management connections and relays.
// mux is an RWMutex so hot read paths (notably PeerStateByIP, called for
// every private-service request) don't contend against each other.
// Pure read methods take RLock; anything that mutates state takes Lock.
type Status struct {
mux sync.Mutex
mux sync.RWMutex
peers map[string]State
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
signalState bool
@@ -283,8 +286,8 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string)
// GetPeer adds peer to Daemon status map
func (d *Status) GetPeer(peerPubKey string) (State, error) {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
state, ok := d.peers[peerPubKey]
if !ok {
@@ -294,8 +297,8 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
}
func (d *Status) PeerByIP(ip string) (string, bool) {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
for _, state := range d.peers {
if state.IP == ip {
@@ -305,6 +308,25 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
return "", false
}
// PeerStateByIP returns the full peer State for the given tunnel IP.
// Matches against either the IPv4 (State.IP) or IPv6 (State.IPv6) tunnel
// address so dual-stack peers are reachable on either family. Returns the
// zero State and false when no peer matches or the input is empty.
func (d *Status) PeerStateByIP(ip string) (State, bool) {
if ip == "" {
return State{}, false
}
d.mux.RLock()
defer d.mux.RUnlock()
for _, state := range d.peers {
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
return state, true
}
}
return State{}, false
}
// RemovePeer removes peer from Daemon status map
func (d *Status) RemovePeer(peerPubKey string) error {
d.mux.Lock()
@@ -702,8 +724,8 @@ func (d *Status) UnsubscribePeerStateChanges(subscription *StatusChangeSubscript
// GetLocalPeerState returns the local peer state
func (d *Status) GetLocalPeerState() LocalPeerState {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
return d.localPeer.Clone()
}
@@ -909,8 +931,8 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
}
func (d *Status) GetRosenpassState() RosenpassState {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
return RosenpassState{
d.rosenpassEnabled,
d.rosenpassPermissive,
@@ -918,14 +940,14 @@ func (d *Status) GetRosenpassState() RosenpassState {
}
func (d *Status) GetLazyConnection() bool {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
return d.lazyConnectionEnabled
}
func (d *Status) GetManagementState() ManagementState {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
return ManagementState{
d.mgmAddress,
d.managementState,
@@ -951,8 +973,8 @@ func (d *Status) UpdateLatency(pubKey string, latency time.Duration) error {
// IsLoginRequired determines if a peer's login has expired.
func (d *Status) IsLoginRequired() bool {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
// if peer is connected to the management then login is not expired
if d.managementState {
@@ -967,8 +989,8 @@ func (d *Status) IsLoginRequired() bool {
}
func (d *Status) GetSignalState() SignalState {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
return SignalState{
d.signalAddress,
d.signalState,
@@ -978,8 +1000,8 @@ func (d *Status) GetSignalState() SignalState {
// GetRelayStates returns the stun/turn/permanent relay states
func (d *Status) GetRelayStates() []relay.ProbeResult {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
if d.relayMgr == nil {
return d.relayStates
}
@@ -1008,8 +1030,8 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
}
func (d *Status) ForwardingRules() []firewall.ForwardRule {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
if d.ingressGwMgr == nil {
return nil
}
@@ -1018,16 +1040,16 @@ func (d *Status) ForwardingRules() []firewall.ForwardRule {
}
func (d *Status) GetDNSStates() []NSGroupState {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
// shallow copy is good enough, as slices fields are currently not updated
return slices.Clone(d.nsGroupStates)
}
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
return maps.Clone(d.resolvedDomainsStates)
}
@@ -1043,8 +1065,8 @@ func (d *Status) GetFullStatus() FullStatus {
LazyConnectionEnabled: d.GetLazyConnection(),
}
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
fullStatus.LocalPeerState = d.localPeer
@@ -1219,8 +1241,8 @@ func (d *Status) SetWgIface(wgInterface WGIfaceStatus) {
}
func (d *Status) PeersStatus() (*configurer.Stats, error) {
d.mux.Lock()
defer d.mux.Unlock()
d.mux.RLock()
defer d.mux.RUnlock()
if d.wgIface == nil {
return nil, fmt.Errorf("wgInterface is nil, cannot retrieve peers status")
}

View File

@@ -63,6 +63,33 @@ func TestUpdatePeerState(t *testing.T) {
assert.Equal(t, ip, state.IP, "ip should be equal")
}
func TestStatus_PeerStateByIP(t *testing.T) {
status := NewRecorder("https://mgm")
req := require.New(t)
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", ""))
req.NoError(status.AddPeer("pk-2", "peer-2.netbird", "100.64.0.11", ""))
state, ok := status.PeerStateByIP("100.64.0.10")
req.True(ok, "known tunnel IP should resolve to a peer state")
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
req.Equal("peer-1.netbird", state.FQDN, "matching state must carry the right FQDN")
_, ok = status.PeerStateByIP("100.64.0.99")
req.False(ok, "unknown IP must report ok=false")
}
func TestStatus_PeerStateByIP_MatchesIPv6(t *testing.T) {
status := NewRecorder("https://mgm")
req := require.New(t)
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", "fd00::1"))
state, ok := status.PeerStateByIP("fd00::1")
req.True(ok, "IPv6-only match must resolve to the peer state")
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
}
func TestStatus_UpdatePeerFQDN(t *testing.T) {
key := "abc"
fqdn := "peer-a.netbird.local"

View File

@@ -21,6 +21,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/owner"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/ssh"
mgm "github.com/netbirdio/netbird/shared/management/client"
@@ -99,6 +100,10 @@ type ConfigInput struct {
LazyConnectionEnabled *bool
MTU *uint16
// OwnerUIDs sets the UIDs of users allowed to control the daemon.
// When non-nil, replaces the config's OwnerUIDs.
OwnerUIDs []owner.UID
}
// Config Configuration type
@@ -174,6 +179,12 @@ type Config struct {
LazyConnectionEnabled bool
MTU uint16
// OwnerUIDs controls who can perform privileged daemon operations via the gRPC socket.
// nil (absent from JSON): TOFU mode, first privileged caller claims ownership (backward compat for existing installs).
// [] (empty slice): root-only, no non-root owners until explicitly set via "netbird up --owner".
// [uid1, uid2, ...]: these UIDs plus root can perform privileged operations.
OwnerUIDs []owner.UID `json:"OwnerUIDs"`
}
var ConfigDirOverride string
@@ -234,10 +245,18 @@ func fileExists(path string) (bool, error) {
// createNewConfig creates a new config generating a new Wireguard key and saving to file
func createNewConfig(input ConfigInput) (*Config, error) {
// Seed owner UIDs from environment if set (for MDM deployments),
// otherwise default to root-only (empty slice).
ownerUIDs := owner.OwnerUIDsFromEnv()
if ownerUIDs == nil {
ownerUIDs = []owner.UID{}
}
config := &Config{
// defaults to false only for new (post 0.26) configurations
ServerSSHAllowed: util.False(),
WgPort: iface.DefaultWgPort,
OwnerUIDs: ownerUIDs,
}
if _, err := config.apply(input); err != nil {
@@ -612,6 +631,14 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.OwnerUIDs != nil {
if !slices.Equal(config.OwnerUIDs, input.OwnerUIDs) {
log.Infof("updating owner UIDs to %v", input.OwnerUIDs)
config.OwnerUIDs = input.OwnerUIDs
updated = true
}
}
return updated, nil
}

View File

@@ -13,6 +13,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/owner"
"github.com/netbirdio/netbird/util"
)
@@ -243,7 +244,10 @@ func (s *ServiceManager) DefaultProfilePath() string {
return DefaultConfigPath
}
func (s *ServiceManager) AddProfile(profileName, username string) error {
// AddProfile creates a new profile with the given name. inheritOwnerUIDs is
// applied to the new profile's OwnerUIDs (pass the active profile's owners so
// the caller stays authorized; pass nil to leave the default empty/env-seeded).
func (s *ServiceManager) AddProfile(profileName, username string, inheritOwnerUIDs []owner.UID) error {
configDir, err := s.getConfigDir(username)
if err != nil {
return fmt.Errorf("failed to get config directory: %w", err)
@@ -264,7 +268,7 @@ func (s *ServiceManager) AddProfile(profileName, username string) error {
return ErrProfileAlreadyExists
}
cfg, err := createNewConfig(ConfigInput{ConfigPath: profPath})
cfg, err := createNewConfig(ConfigInput{ConfigPath: profPath, OwnerUIDs: inheritOwnerUIDs})
if err != nil {
return fmt.Errorf("failed to create new config: %w", err)
}

View File

@@ -28,6 +28,15 @@ func hashRosenpassKey(key []byte) string {
return hex.EncodeToString(hasher.Sum(nil))
}
// rpServer is the subset of rp.Server used by Manager. Defined as an interface
// so tests can substitute a mock without spinning up a real UDP server.
type rpServer interface {
AddPeer(rp.PeerConfig) (rp.PeerID, error)
RemovePeer(rp.PeerID) error
Run() error
Close() error
}
type Manager struct {
ifaceName string
spk []byte
@@ -36,7 +45,7 @@ type Manager struct {
preSharedKey *[32]byte
rpPeerIDs map[string]*rp.PeerID
rpWgHandler *NetbirdHandler
server *rp.Server
server rpServer
lock sync.Mutex
port int
wgIface PresharedKeySetter
@@ -51,7 +60,22 @@ func NewManager(preSharedKey *wgtypes.Key, wgIfaceName string) (*Manager, error)
rpKeyHash := hashRosenpassKey(public)
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
return &Manager{ifaceName: wgIfaceName, rpKeyHash: rpKeyHash, spk: public, ssk: secret, preSharedKey: (*[32]byte)(preSharedKey), rpPeerIDs: make(map[string]*rp.PeerID), lock: sync.Mutex{}}, nil
return &Manager{
ifaceName: wgIfaceName,
rpKeyHash: rpKeyHash,
spk: public,
ssk: secret,
preSharedKey: (*[32]byte)(preSharedKey),
rpPeerIDs: make(map[string]*rp.PeerID),
// rpWgHandler is created here (instead of only in generateConfig) so it
// is never nil between NewManager and Run(). Otherwise an early
// OnConnected call (race observed on Android, issue #4341) panics on
// nil receiver in addPeer -> m.rpWgHandler.AddPeer. generateConfig will
// replace it with a fresh handler on each Run() to clear stale peer
// state from previous engine sessions.
rpWgHandler: NewNetbirdHandler(),
lock: sync.Mutex{},
}, nil
}
func (m *Manager) GetPubKey() []byte {
@@ -65,6 +89,16 @@ func (m *Manager) GetAddress() *net.UDPAddr {
// addPeer adds a new peer to the Rosenpass server
func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuardIP string, wireGuardPubKey string) error {
// Defense in depth against issue #4341 (Android crash): if Run() has not
// completed yet, m.server / m.rpWgHandler may be nil. Return an explicit
// error instead of panicking on nil-receiver dereference.
if m.server == nil {
return fmt.Errorf("rosenpass server not initialized")
}
if m.rpWgHandler == nil {
return fmt.Errorf("rosenpass wg handler not initialized")
}
var err error
pcfg := rp.PeerConfig{PublicKey: rosenpassPubKey}
if m.preSharedKey != nil {
@@ -79,6 +113,16 @@ func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuar
if pcfg.Endpoint, err = net.ResolveUDPAddr("udp", peerAddr); err != nil {
return fmt.Errorf("failed to resolve peer endpoint address: %w", err)
}
// Our local Rosenpass UDP server binds on the IPv6 wildcard ([::]) — see
// GetAddress(). The remote peer's endpoint (pcfg.Endpoint) is the destination
// our server will sendto when initiating handshakes. ResolveUDPAddr returns a
// 4-byte IPv4 for IPv4 hosts, which the kernel rejects (EDESTADDRREQ) when
// sent from an AF_INET6 socket. Normalize the remote endpoint to IPv4-mapped
// IPv6 so its address family matches our listening socket.
// TODO: maybe bind the Rosenpass UDP server to the peer wg IP addr
if v4 := pcfg.Endpoint.IP.To4(); v4 != nil {
pcfg.Endpoint.IP = v4.To16()
}
}
peerID, err := m.server.AddPeer(pcfg)
if err != nil {
@@ -182,24 +226,31 @@ func (m *Manager) Run() error {
return err
}
m.server, err = rp.NewUDPServer(conf)
server, err := rp.NewUDPServer(conf)
if err != nil {
return err
}
m.lock.Lock()
m.server = server
m.lock.Unlock()
log.Infof("starting rosenpass server on port %d", m.port)
return m.server.Run()
return server.Run()
}
// Close closes the Rosenpass server
func (m *Manager) Close() error {
if m.server != nil {
err := m.server.Close()
if err != nil {
log.Errorf("failed closing local rosenpass server")
}
m.server = nil
m.lock.Lock()
server := m.server
m.server = nil
m.lock.Unlock()
if server == nil {
return nil
}
if err := server.Close(); err != nil {
log.Errorf("failed closing local rosenpass server: %v", err)
}
return nil
}

View File

@@ -1,14 +1,412 @@
package rosenpass
import (
"errors"
"os"
"sync"
"testing"
rp "cunicu.li/go-rosenpass"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// --- test doubles -----------------------------------------------------------
type addPeerCall struct {
cfg rp.PeerConfig
}
type removePeerCall struct {
id rp.PeerID
}
type mockServer struct {
mu sync.Mutex
addCalls []addPeerCall
removed []removePeerCall
nextID rp.PeerID
addErr error
removeErr error
closed bool
ran bool
}
func (m *mockServer) AddPeer(cfg rp.PeerConfig) (rp.PeerID, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.addCalls = append(m.addCalls, addPeerCall{cfg: cfg})
if m.addErr != nil {
return rp.PeerID{}, m.addErr
}
// Increment a byte in nextID so distinct peers get distinct IDs.
m.nextID[0]++
return m.nextID, nil
}
func (m *mockServer) RemovePeer(id rp.PeerID) error {
m.mu.Lock()
defer m.mu.Unlock()
m.removed = append(m.removed, removePeerCall{id: id})
return m.removeErr
}
func (m *mockServer) Run() error { m.ran = true; return nil }
func (m *mockServer) Close() error { m.closed = true; return nil }
type setPSKCall struct {
peerKey string
psk wgtypes.Key
updateOnly bool
}
type mockIface struct {
mu sync.Mutex
calls []setPSKCall
err error
}
func (m *mockIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
m.mu.Lock()
defer m.mu.Unlock()
m.calls = append(m.calls, setPSKCall{peerKey: peerKey, psk: psk, updateOnly: updateOnly})
return m.err
}
// newTestManager builds a Manager with deterministic spk so tie-break
// against a peer pubkey is controllable from tests. The provided spk byte
// becomes the first byte; remaining bytes are zero.
func newTestManager(spkFirstByte byte, mock *mockServer) *Manager {
spk := make([]byte, 32)
spk[0] = spkFirstByte
return &Manager{
ifaceName: "wt0",
spk: spk,
ssk: make([]byte, 32),
rpKeyHash: "test-hash",
rpPeerIDs: make(map[string]*rp.PeerID),
rpWgHandler: NewNetbirdHandler(),
server: mock,
}
}
// validWGKey returns a deterministic 32-byte wireguard public key (base64).
func validWGKey(t *testing.T, lastByte byte) string {
t.Helper()
var k wgtypes.Key
k[31] = lastByte
return k.String()
}
// --- pure helpers ----------------------------------------------------------
func TestHashRosenpassKey_Deterministic(t *testing.T) {
key := []byte("hello-rosenpass")
require.Equal(t, hashRosenpassKey(key), hashRosenpassKey(key))
require.Len(t, hashRosenpassKey(key), 64) // sha256 hex
}
func TestHashRosenpassKey_DifferentInputsDifferOutputs(t *testing.T) {
require.NotEqual(t, hashRosenpassKey([]byte("a")), hashRosenpassKey([]byte("b")))
}
func TestGetLogLevel_DefaultWhenUnset(t *testing.T) {
// Snapshot + unset to exercise the LookupEnv ok=false branch. t.Setenv
// can only set, not delete, so do it manually with restore via t.Cleanup.
prev, hadPrev := os.LookupEnv(defaultLogLevelVar)
require.NoError(t, os.Unsetenv(defaultLogLevelVar))
t.Cleanup(func() {
if hadPrev {
_ = os.Setenv(defaultLogLevelVar, prev)
} else {
_ = os.Unsetenv(defaultLogLevelVar)
}
})
require.Equal(t, defaultLog.String(), getLogLevel().String())
}
func TestGetLogLevel_Cases(t *testing.T) {
cases := map[string]string{
"debug": "DEBUG",
"info": "INFO",
"warn": "WARN",
"error": "ERROR",
"unknown": "INFO", // default fallback
}
for input, wantStr := range cases {
input, wantStr := input, wantStr
t.Run(input, func(t *testing.T) {
t.Setenv(defaultLogLevelVar, input)
require.Equal(t, wantStr, getLogLevel().String())
})
}
}
func TestFindRandomAvailableUDPPort(t *testing.T) {
port, err := findRandomAvailableUDPPort()
require.NoError(t, err)
require.Greater(t, port, 0)
require.LessOrEqual(t, port, 65535)
}
// --- addPeer ---------------------------------------------------------------
func TestAddPeer_HigherLocalPubkey_SetsEndpoint(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv) // local spk lexicographically larger
remotePubKey := make([]byte, 32) // remote spk = all zeros (smaller)
err := m.addPeer(remotePubKey, "rosenpass-host:7000", "100.1.1.1", validWGKey(t, 1))
require.NoError(t, err)
require.Len(t, srv.addCalls, 1)
ep := srv.addCalls[0].cfg.Endpoint
require.NotNil(t, ep, "initiator side must set Endpoint")
require.Equal(t, 7000, ep.Port)
require.Equal(t, "100.1.1.1", ep.IP.String())
}
func TestAddPeer_HigherLocalPubkey_EndpointIPIsIPv4Mapped(t *testing.T) {
// Regression guard for the EDESTADDRREQ fix: Endpoint.IP must be 16-byte
// (IPv4-mapped IPv6) so it matches the AF_INET6 listening socket family.
srv := &mockServer{}
m := newTestManager(0xFF, srv)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.NoError(t, err)
ep := srv.addCalls[0].cfg.Endpoint
require.NotNil(t, ep)
require.Len(t, ep.IP, 16, "IPv4 endpoint must be normalized to 16-byte v4-mapped form")
require.True(t, ep.IP.To4() != nil, "Endpoint must still be detected as IPv4")
}
func TestAddPeer_LowerLocalPubkey_LeavesEndpointNil(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0x00, srv) // local spk smaller
remotePubKey := make([]byte, 32)
remotePubKey[0] = 0xFF
err := m.addPeer(remotePubKey, "rp:5000", "100.1.1.1", validWGKey(t, 2))
require.NoError(t, err)
require.Nil(t, srv.addCalls[0].cfg.Endpoint, "responder side must NOT set Endpoint")
}
func TestAddPeer_PresharedKeyPropagated(t *testing.T) {
srv := &mockServer{}
psk := &wgtypes.Key{0x42}
m := newTestManager(0xFF, srv)
m.preSharedKey = (*[32]byte)(psk)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 3))
require.NoError(t, err)
require.Equal(t, [32]byte(*psk), [32]byte(srv.addCalls[0].cfg.PresharedKey))
}
func TestAddPeer_InvalidRosenpassAddr_ReturnsError(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv) // initiator path → parses rosenpassAddr
err := m.addPeer(make([]byte, 32), "not-a-host-port", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
require.Empty(t, srv.addCalls, "server.AddPeer must not run when address parse fails")
}
func TestAddPeer_InvalidWireGuardPubKey_ReturnsError(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", "not-a-valid-key")
require.Error(t, err)
}
func TestAddPeer_ServerError_Propagates(t *testing.T) {
srv := &mockServer{addErr: errors.New("boom")}
m := newTestManager(0xFF, srv)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
}
// Regression guard for issue #4341 (Android crash). If Run() has not completed
// before OnConnected fires, m.rpWgHandler or m.server may be nil. Without the
// nil guards, m.rpWgHandler.AddPeer panics on nil receiver.
func TestAddPeer_NilHandler_ReturnsErrorNoCrash(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
m.rpWgHandler = nil // simulate Run() not yet completed
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
require.Contains(t, err.Error(), "wg handler not initialized")
}
func TestAddPeer_NilServer_ReturnsErrorNoCrash(t *testing.T) {
m := newTestManager(0xFF, nil)
m.server = nil // simulate Run() not yet completed
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
require.Contains(t, err.Error(), "server not initialized")
}
// NewManager must pre-initialize rpWgHandler so the nil-receiver crash from
// issue #4341 cannot occur in the window between NewManager and Run().
func TestNewManager_PreInitializesHandler(t *testing.T) {
psk := wgtypes.Key{}
m, err := NewManager(&psk, "wt0")
require.NoError(t, err)
require.NotNil(t, m.rpWgHandler, "rpWgHandler must be initialized in NewManager")
}
func TestAddPeer_RecordsPeerID(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 5)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey)
require.NoError(t, err)
require.Contains(t, m.rpPeerIDs, wgKey)
}
// --- OnConnected / OnDisconnected ------------------------------------------
func TestOnConnected_NilRemotePubKey_NoAddPeer(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
m.OnConnected(validWGKey(t, 1), nil, "100.1.1.1", "rp:5000")
require.Empty(t, srv.addCalls, "nil remote rosenpass pubkey must skip AddPeer")
require.Empty(t, m.rpPeerIDs)
}
func TestOnConnected_ValidPubKey_CallsAddPeer(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 1)
m.OnConnected(wgKey, make([]byte, 32), "100.1.1.1", "rp:5000")
require.Len(t, srv.addCalls, 1)
require.Contains(t, m.rpPeerIDs, wgKey)
}
func TestOnDisconnected_UnknownPeer_NoOp(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
m.OnDisconnected(validWGKey(t, 99))
require.Empty(t, srv.removed, "unknown peer key must not call RemovePeer")
}
func TestOnDisconnected_KnownPeer_CallsRemoveAndForgets(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 1)
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
require.Contains(t, m.rpPeerIDs, wgKey)
m.OnDisconnected(wgKey)
require.Len(t, srv.removed, 1)
require.NotContains(t, m.rpPeerIDs, wgKey, "peer must be forgotten after disconnect")
}
// --- IsPresharedKeyInitialized ---------------------------------------------
func TestIsPresharedKeyInitialized_UnknownPeer_ReturnsFalse(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
require.False(t, m.IsPresharedKeyInitialized(validWGKey(t, 1)))
}
func TestIsPresharedKeyInitialized_AddedButNotHandshaken_ReturnsFalse(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 2)
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
require.False(t, m.IsPresharedKeyInitialized(wgKey))
}
// --- NetbirdHandler.outputKey ----------------------------------------------
func TestHandler_OutputKey_FirstCallUsesUpdateOnlyFalse(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
pid := rp.PeerID{0x01}
wgKey := wgtypes.Key{0xAA}
h.AddPeer(pid, "wt0", rp.Key(wgKey))
psk := rp.Key{0xBB}
h.HandshakeCompleted(pid, psk)
require.Len(t, iface.calls, 1)
require.False(t, iface.calls[0].updateOnly, "first PSK rotation must use updateOnly=false")
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
}
func TestHandler_OutputKey_SubsequentCallsUseUpdateOnlyTrue(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
pid := rp.PeerID{0x02}
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xCC}))
h.HandshakeCompleted(pid, rp.Key{0x01}) // first
h.HandshakeCompleted(pid, rp.Key{0x02}) // second
require.Len(t, iface.calls, 2)
require.False(t, iface.calls[0].updateOnly)
require.True(t, iface.calls[1].updateOnly, "subsequent rotations must use updateOnly=true")
}
func TestHandler_OutputKey_NilInterface_NoCrashNoCall(t *testing.T) {
h := NewNetbirdHandler()
// no SetInterface — iface remains nil
pid := rp.PeerID{0x03}
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{}))
// Must not panic.
h.HandshakeCompleted(pid, rp.Key{})
}
func TestHandler_OutputKey_UnknownPeer_NoCall(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
h.HandshakeCompleted(rp.PeerID{0xFF}, rp.Key{})
require.Empty(t, iface.calls, "unknown peer id must not trigger SetPresharedKey")
}
func TestHandler_RemovePeer_ClearsInitializedState(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
pid := rp.PeerID{0x04}
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xDD}))
h.HandshakeCompleted(pid, rp.Key{0x01})
require.True(t, h.IsPeerInitialized(pid))
h.RemovePeer(pid)
require.False(t, h.IsPeerInitialized(pid), "RemovePeer must clear initialized flag")
}
func TestHandler_SetInterfaceAfterAddPeer_StillReceivesKey(t *testing.T) {
h := NewNetbirdHandler()
pid := rp.PeerID{0x05}
wgKey := wgtypes.Key{0xEE}
h.AddPeer(pid, "wt0", rp.Key(wgKey))
iface := &mockIface{}
h.SetInterface(iface) // set after AddPeer
h.HandshakeCompleted(pid, rp.Key{0x42})
require.Len(t, iface.calls, 1)
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
}

View File

@@ -0,0 +1,42 @@
package rosenpass
import (
"fmt"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// DeterministicSeedKey derives a 32-byte WireGuard preshared key from a pair
// of peer public keys. Both peers, given the same key pair, produce the same
// output regardless of which side runs the function: the inputs are ordered
// lexicographically before concatenation.
//
// NetBird uses this value as the initial Rosenpass-side preshared key when no
// explicit account-level PSK is configured, so both peers converge on the same
// PSK before the first post-quantum handshake completes.
//
// The resulting key MUST NOT be treated as quantum-safe: it is deterministic
// from public keys and exists only to seed WireGuard until Rosenpass rotates
// in a real post-quantum PSK.
func DeterministicSeedKey(localKey, remoteKey string) (*wgtypes.Key, error) {
lk := []byte(localKey)
rk := []byte(remoteKey)
if len(lk) < 16 || len(rk) < 16 {
return nil, fmt.Errorf("rosenpass: peer keys must be at least 16 bytes (got local=%d, remote=%d)", len(lk), len(rk))
}
var keyInput []byte
if localKey > remoteKey {
keyInput = append(keyInput, lk[:16]...)
keyInput = append(keyInput, rk[:16]...)
} else {
keyInput = append(keyInput, rk[:16]...)
keyInput = append(keyInput, lk[:16]...)
}
key, err := wgtypes.NewKey(keyInput)
if err != nil {
return nil, fmt.Errorf("rosenpass: deterministic seed key: %w", err)
}
return &key, nil
}

View File

@@ -0,0 +1,44 @@
package rosenpass
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestDeterministicSeedKey_SameForBothSides(t *testing.T) {
// Peer A and peer B must derive the same PSK regardless of which side
// computes it: the function orders inputs internally.
a := strings.Repeat("a", 32)
b := strings.Repeat("b", 32)
keyAB, err := DeterministicSeedKey(a, b)
require.NoError(t, err)
keyBA, err := DeterministicSeedKey(b, a)
require.NoError(t, err)
require.Equal(t, keyAB.String(), keyBA.String(), "swapping arguments must yield identical key")
}
func TestDeterministicSeedKey_ChangesWithKeys(t *testing.T) {
a := strings.Repeat("a", 32)
b := strings.Repeat("b", 32)
c := strings.Repeat("c", 32)
keyAB, err := DeterministicSeedKey(a, b)
require.NoError(t, err)
keyAC, err := DeterministicSeedKey(a, c)
require.NoError(t, err)
require.NotEqual(t, keyAB.String(), keyAC.String(), "different peer pair must yield different key")
}
func TestDeterministicSeedKey_TooShortKey_ReturnsError(t *testing.T) {
short := "short" // < 16 bytes
long := strings.Repeat("x", 32)
_, err := DeterministicSeedKey(short, long)
require.Error(t, err)
_, err = DeterministicSeedKey(long, short)
require.Error(t, err)
}

View File

@@ -0,0 +1,9 @@
//go:build dragonfly || freebsd || netbsd || openbsd
package systemops
// IgnoreAddedDefaultRoute reports whether an RTM_ADD default route with the
// given flags should be ignored by the network monitor.
func IgnoreAddedDefaultRoute(flags int) bool {
return filterRoutesByFlags(flags)
}

View File

@@ -0,0 +1,21 @@
//go:build darwin
package systemops
import "golang.org/x/sys/unix"
// IgnoreAddedDefaultRoute reports whether an RTM_ADD default route with the
// given flags should be ignored by the network monitor. Scoped routes
// (RTF_IFSCOPE) are tied to a specific interface index and cannot replace the
// unscoped default the kernel uses for general egress, so flapping ones (e.g.
// Wi-Fi calling IMS tunnels on ipsec0, Docker bridges, scoped utun defaults)
// must not trigger an engine restart.
func IgnoreAddedDefaultRoute(flags int) bool {
if filterRoutesByFlags(flags) {
return true
}
if flags&unix.RTF_IFSCOPE != 0 {
return true
}
return false
}

View File

@@ -96,17 +96,19 @@ func (m *Manager) Stop(ctx context.Context) error {
}
m.mu.Lock()
defer m.mu.Unlock()
cancel := m.cancel
done := m.done
m.mu.Unlock()
if m.cancel == nil {
if cancel == nil {
return nil
}
m.cancel()
cancel()
select {
case <-ctx.Done():
return ctx.Err()
case <-m.done:
case <-done:
}
return nil

View File

@@ -64,13 +64,6 @@
<RegistryValue Name="InstalledByMSI" Type="integer" Value="1" KeyPath="yes" />
</RegistryKey>
</Component>
<!-- Drop the HKCU Run\Netbird value written by legacy NSIS installers. -->
<Component Id="NetbirdLegacyHKCUCleanup" Guid="*">
<RegistryValue Root="HKCU" Key="Software\NetBird GmbH\Installer"
Name="LegacyHKCUCleanup" Type="integer" Value="1" KeyPath="yes" />
<RemoveRegistryValue Root="HKCU"
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
</Component>
</StandardDirectory>
<StandardDirectory Id="CommonAppDataFolder">
@@ -83,28 +76,10 @@
</Directory>
</StandardDirectory>
<!-- Drop Run, App Paths and Uninstall entries written by legacy NSIS
installers into the 32-bit registry view (HKLM\Software\Wow6432Node). -->
<Component Id="NetbirdLegacyWow6432Cleanup" Directory="NetbirdInstallDir"
Guid="bda5d628-16bd-4086-b2c1-5099d8d51763" Bitness="always32">
<RegistryValue Root="HKLM" Key="Software\NetBird GmbH\Installer"
Name="LegacyWow6432Cleanup" Type="integer" Value="1" KeyPath="yes" />
<RemoveRegistryValue Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird-ui" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\Uninstall\Netbird" />
</Component>
<ComponentGroup Id="NetbirdFilesComponent">
<ComponentRef Id="NetbirdFiles" />
<ComponentRef Id="NetbirdAumidRegistry" />
<ComponentRef Id="NetbirdAutoStart" />
<ComponentRef Id="NetbirdLegacyHKCUCleanup" />
<ComponentRef Id="NetbirdLegacyWow6432Cleanup" />
</ComponentGroup>
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />

File diff suppressed because it is too large Load Diff

View File

@@ -91,6 +91,15 @@ service DaemonService {
rpc GetActiveProfile(GetActiveProfileRequest) returns (GetActiveProfileResponse) {}
// AddOwner adds a UID to the active profile's owner list. Requires
// root or an existing owner.
rpc AddOwner(AddOwnerRequest) returns (AddOwnerResponse) {}
// ResetOwner clears the active profile's owner list, returning it to
// the unconfigured state. The next call from the active console-session
// user will then re-claim ownership. Requires root.
rpc ResetOwner(ResetOwnerRequest) returns (ResetOwnerResponse) {}
// Logout disconnects from the network and deletes the peer from the management server
rpc Logout(LogoutRequest) returns (LogoutResponse) {}
@@ -227,6 +236,10 @@ message UpRequest {
optional string profileName = 1;
optional string username = 2;
reserved 3;
// When true, the caller claims owner privileges for this profile.
// Requires root or current owner; for new installs (root-only mode),
// the calling UID becomes an owner.
bool claimOwner = 4;
}
message UpResponse {}
@@ -689,6 +702,16 @@ message AddProfileRequest {
message AddProfileResponse {}
message AddOwnerRequest {
uint32 uid = 1;
}
message AddOwnerResponse {}
message ResetOwnerRequest {}
message ResetOwnerResponse {}
message RemoveProfileRequest {
string username = 1;
string profileName = 2;

View File

@@ -48,6 +48,8 @@ const (
DaemonService_RemoveProfile_FullMethodName = "/daemon.DaemonService/RemoveProfile"
DaemonService_ListProfiles_FullMethodName = "/daemon.DaemonService/ListProfiles"
DaemonService_GetActiveProfile_FullMethodName = "/daemon.DaemonService/GetActiveProfile"
DaemonService_AddOwner_FullMethodName = "/daemon.DaemonService/AddOwner"
DaemonService_ResetOwner_FullMethodName = "/daemon.DaemonService/ResetOwner"
DaemonService_Logout_FullMethodName = "/daemon.DaemonService/Logout"
DaemonService_GetFeatures_FullMethodName = "/daemon.DaemonService/GetFeatures"
DaemonService_TriggerUpdate_FullMethodName = "/daemon.DaemonService/TriggerUpdate"
@@ -115,6 +117,13 @@ type DaemonServiceClient interface {
RemoveProfile(ctx context.Context, in *RemoveProfileRequest, opts ...grpc.CallOption) (*RemoveProfileResponse, error)
ListProfiles(ctx context.Context, in *ListProfilesRequest, opts ...grpc.CallOption) (*ListProfilesResponse, error)
GetActiveProfile(ctx context.Context, in *GetActiveProfileRequest, opts ...grpc.CallOption) (*GetActiveProfileResponse, error)
// AddOwner adds a UID to the active profile's owner list. Requires
// root or an existing owner.
AddOwner(ctx context.Context, in *AddOwnerRequest, opts ...grpc.CallOption) (*AddOwnerResponse, error)
// ResetOwner clears the active profile's owner list, returning it to
// the unconfigured state. The next call from the active console-session
// user will then re-claim ownership. Requires root.
ResetOwner(ctx context.Context, in *ResetOwnerRequest, opts ...grpc.CallOption) (*ResetOwnerResponse, error)
// Logout disconnects from the network and deletes the peer from the management server
Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error)
GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error)
@@ -452,6 +461,26 @@ func (c *daemonServiceClient) GetActiveProfile(ctx context.Context, in *GetActiv
return out, nil
}
func (c *daemonServiceClient) AddOwner(ctx context.Context, in *AddOwnerRequest, opts ...grpc.CallOption) (*AddOwnerResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(AddOwnerResponse)
err := c.cc.Invoke(ctx, DaemonService_AddOwner_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) ResetOwner(ctx context.Context, in *ResetOwnerRequest, opts ...grpc.CallOption) (*ResetOwnerResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(ResetOwnerResponse)
err := c.cc.Invoke(ctx, DaemonService_ResetOwner_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(LogoutResponse)
@@ -616,6 +645,13 @@ type DaemonServiceServer interface {
RemoveProfile(context.Context, *RemoveProfileRequest) (*RemoveProfileResponse, error)
ListProfiles(context.Context, *ListProfilesRequest) (*ListProfilesResponse, error)
GetActiveProfile(context.Context, *GetActiveProfileRequest) (*GetActiveProfileResponse, error)
// AddOwner adds a UID to the active profile's owner list. Requires
// root or an existing owner.
AddOwner(context.Context, *AddOwnerRequest) (*AddOwnerResponse, error)
// ResetOwner clears the active profile's owner list, returning it to
// the unconfigured state. The next call from the active console-session
// user will then re-claim ownership. Requires root.
ResetOwner(context.Context, *ResetOwnerRequest) (*ResetOwnerResponse, error)
// Logout disconnects from the network and deletes the peer from the management server
Logout(context.Context, *LogoutRequest) (*LogoutResponse, error)
GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error)
@@ -732,6 +768,12 @@ func (UnimplementedDaemonServiceServer) ListProfiles(context.Context, *ListProfi
func (UnimplementedDaemonServiceServer) GetActiveProfile(context.Context, *GetActiveProfileRequest) (*GetActiveProfileResponse, error) {
return nil, status.Error(codes.Unimplemented, "method GetActiveProfile not implemented")
}
func (UnimplementedDaemonServiceServer) AddOwner(context.Context, *AddOwnerRequest) (*AddOwnerResponse, error) {
return nil, status.Error(codes.Unimplemented, "method AddOwner not implemented")
}
func (UnimplementedDaemonServiceServer) ResetOwner(context.Context, *ResetOwnerRequest) (*ResetOwnerResponse, error) {
return nil, status.Error(codes.Unimplemented, "method ResetOwner not implemented")
}
func (UnimplementedDaemonServiceServer) Logout(context.Context, *LogoutRequest) (*LogoutResponse, error) {
return nil, status.Error(codes.Unimplemented, "method Logout not implemented")
}
@@ -1291,6 +1333,42 @@ func _DaemonService_GetActiveProfile_Handler(srv interface{}, ctx context.Contex
return interceptor(ctx, in, info, handler)
}
func _DaemonService_AddOwner_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(AddOwnerRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).AddOwner(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: DaemonService_AddOwner_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).AddOwner(ctx, req.(*AddOwnerRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_ResetOwner_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ResetOwnerRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).ResetOwner(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: DaemonService_ResetOwner_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).ResetOwner(ctx, req.(*ResetOwnerRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_Logout_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(LogoutRequest)
if err := dec(in); err != nil {
@@ -1579,6 +1657,14 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetActiveProfile",
Handler: _DaemonService_GetActiveProfile_Handler,
},
{
MethodName: "AddOwner",
Handler: _DaemonService_AddOwner_Handler,
},
{
MethodName: "ResetOwner",
Handler: _DaemonService_ResetOwner_Handler,
},
{
MethodName: "Logout",
Handler: _DaemonService_Logout_Handler,

172
client/server/owner.go Normal file
View File

@@ -0,0 +1,172 @@
package server
import (
"context"
"fmt"
"slices"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/owner"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/util"
)
// authorizeTargetProfile enforces the "match or root" rule for operations
// that target a specific profile (Remove/Switch). The caller must be root
// or appear in the target profile config's OwnerUIDs. A target profile in
// legacy TOFU state (nil OwnerUIDs) is treated as unowned and therefore
// accessible to any peer-creds caller, which matches pre-enforcement
// behavior on upgraded installs.
func (s *Server) authorizeTargetProfile(ctx context.Context, profileName, username string) error {
uid, ok := owner.UIDFromContext(ctx)
if !ok {
return status.Error(codes.PermissionDenied, "peer credentials unavailable")
}
if uid == 0 {
return nil
}
cfg, err := s.readProfileConfig(profileName, username)
if err != nil {
return fmt.Errorf("read target profile config: %w", err)
}
// Legacy / never-claimed target: allow, mirroring the migration TOFU
// semantics in the interceptor.
if cfg.OwnerUIDs == nil {
return nil
}
if slices.Contains(cfg.OwnerUIDs, uid) {
return nil
}
return status.Errorf(codes.PermissionDenied,
"profile %q is owned by another user (uid %d is not in its owner list)", profileName, uid)
}
// readProfileConfig loads a profile's config from disk without making it
// active. Used by authorizeTargetProfile.
func (s *Server) readProfileConfig(profileName, username string) (*profilemanager.Config, error) {
state := &profilemanager.ActiveProfileState{Name: profileName, Username: username}
path, err := state.FilePath()
if err != nil {
return nil, fmt.Errorf("resolve profile path: %w", err)
}
cfg, err := profilemanager.GetConfig(path)
if err != nil {
return nil, fmt.Errorf("load %s: %w", path, err)
}
return cfg, nil
}
// GetOwnerUIDs returns the current owner UIDs from the active config.
// nil means TOFU mode, empty means root-only, populated means those UIDs are owners.
func (s *Server) GetOwnerUIDs() []owner.UID {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.config == nil {
return nil
}
return s.config.OwnerUIDs
}
// AddOwnerUID adds the given UID to the owner list in the active profile config.
func (s *Server) AddOwnerUID(uid owner.UID) error {
s.mutex.Lock()
defer s.mutex.Unlock()
return s.addOwnerUIDLocked(uid)
}
// addOwnerUIDLocked adds uid to the active profile's owner list and persists it.
// The caller must hold s.mutex.
func (s *Server) addOwnerUIDLocked(uid owner.UID) error {
if s.config == nil {
return fmt.Errorf("config not loaded")
}
if slices.Contains(s.config.OwnerUIDs, uid) {
return nil
}
s.config.OwnerUIDs = append(s.config.OwnerUIDs, uid)
activeProf, err := s.profileManager.GetActiveProfileState()
if err != nil {
return fmt.Errorf("get active profile: %w", err)
}
cfgPath, err := activeProf.FilePath()
if err != nil {
return fmt.Errorf("get profile file path: %w", err)
}
if err := util.WriteJson(context.Background(), cfgPath, s.config); err != nil {
return fmt.Errorf("write config: %w", err)
}
log.Infof("owner UID %d added in %s (owners: %v)", uid, cfgPath, s.config.OwnerUIDs)
return nil
}
// AddOwner handles the AddOwner RPC. The interceptor has already gated this
// call (caller must be root or an existing owner); the handler just persists
// the new UID into the active profile config.
func (s *Server) AddOwner(_ context.Context, msg *proto.AddOwnerRequest) (*proto.AddOwnerResponse, error) {
if msg == nil || msg.Uid == 0 {
return nil, status.Error(codes.InvalidArgument, "uid must be non-zero")
}
if err := s.AddOwnerUID(owner.UID(msg.Uid)); err != nil {
return nil, fmt.Errorf("add owner: %w", err)
}
return &proto.AddOwnerResponse{}, nil
}
// ResetOwner clears the active profile's owner list. Only callable by root
// (the interceptor enforces this: a non-owner non-root caller is denied
// before reaching the handler, and only owners or root can reach Add/Reset
// at all; we additionally require root here so existing owners can't reset
// each other out).
func (s *Server) ResetOwner(ctx context.Context, _ *proto.ResetOwnerRequest) (*proto.ResetOwnerResponse, error) {
uid, ok := owner.UIDFromContext(ctx)
if !ok {
return nil, status.Error(codes.PermissionDenied, "peer credentials unavailable")
}
if uid != 0 {
return nil, status.Error(codes.PermissionDenied, "reset-owner requires root")
}
s.mutex.Lock()
defer s.mutex.Unlock()
if s.config == nil {
return nil, fmt.Errorf("config not loaded")
}
// Reset to the fresh-install state (empty, not nil): only root and the
// active console-session user can reclaim. nil would be legacy migration
// TOFU, where any non-root caller (including SSH) could reclaim.
s.config.OwnerUIDs = []owner.UID{}
activeProf, err := s.profileManager.GetActiveProfileState()
if err != nil {
return nil, fmt.Errorf("get active profile: %w", err)
}
cfgPath, err := activeProf.FilePath()
if err != nil {
return nil, fmt.Errorf("get profile file path: %w", err)
}
if err := util.WriteJson(context.Background(), cfgPath, s.config); err != nil {
return nil, fmt.Errorf("write config: %w", err)
}
log.Infof("owner list reset; next call from the active console user will re-claim ownership")
return &proto.ResetOwnerResponse{}, nil
}

View File

@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/internal/owner"
"github.com/netbirdio/netbird/client/internal/profilemanager"
sleephandler "github.com/netbirdio/netbird/client/internal/sleep/handler"
"github.com/netbirdio/netbird/client/system"
@@ -735,6 +736,18 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
}
s.config = config
// An explicit --owner claim locks the active profile to the calling user
// (plus root). Root has no specific UID to claim, so only non-root callers
// take effect here; the interceptor has already authorized the call.
if msg != nil && msg.ClaimOwner {
if uid, ok := owner.UIDFromContext(callerCtx); ok && uid != 0 {
if err := s.addOwnerUIDLocked(uid); err != nil {
s.mutex.Unlock()
return nil, fmt.Errorf("claim owner: %w", err)
}
}
}
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
@@ -800,6 +813,18 @@ func (s *Server) switchProfileIfNeeded(profileName string, userName *string, act
// SwitchProfile switches the active profile in the daemon.
func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfileRequest) (*proto.SwitchProfileResponse, error) {
// Switching downs the current session and starts another, so the caller
// must own the target profile (or be root).
if msg != nil && msg.ProfileName != nil {
username := ""
if msg.Username != nil {
username = *msg.Username
}
if err := s.authorizeTargetProfile(callerCtx, *msg.ProfileName, username); err != nil {
return nil, err
}
}
s.mutex.Lock()
defer s.mutex.Unlock()
@@ -1564,7 +1589,17 @@ func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided")
}
if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil {
// New profiles auto-claim the caller as their sole owner so the user who
// just created the profile retains control (and other local users can't
// touch it via SwitchProfile/RemoveProfile). When called by root, leave
// OwnerUIDs at the default (empty/env-seeded); root explicitly didn't
// claim ownership for any specific user.
var initialOwners []owner.UID
if uid, ok := owner.UIDFromContext(ctx); ok && uid != 0 {
initialOwners = []owner.UID{uid}
}
if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username, initialOwners); err != nil {
log.Errorf("failed to create profile: %v", err)
return nil, fmt.Errorf("failed to create profile: %w", err)
}
@@ -1574,6 +1609,10 @@ func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (
// RemoveProfile removes a profile from the daemon.
func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) {
if err := s.authorizeTargetProfile(ctx, msg.ProfileName, msg.Username); err != nil {
return nil, err
}
s.mutex.Lock()
defer s.mutex.Unlock()

View File

@@ -13,7 +13,7 @@ import (
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
@@ -315,7 +315,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)

View File

@@ -3,15 +3,14 @@
package system
import (
"bytes"
"context"
"os"
"os/exec"
"regexp"
"runtime"
"strings"
"time"
"golang.org/x/sys/unix"
log "github.com/sirupsen/logrus"
"github.com/zcalusic/sysinfo"
@@ -29,19 +28,11 @@ func UpdateStaticInfoAsync() {
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
info := _getInfo()
for strings.Contains(info, "broken pipe") {
info = _getInfo()
time.Sleep(500 * time.Millisecond)
}
osStr := strings.ReplaceAll(info, "\n", "")
osStr = strings.ReplaceAll(osStr, "\r\n", "")
osInfo := strings.Split(osStr, " ")
kernelName, kernelVersion, kernelPlatform := kernelInfo()
osName, osVersion := readOsReleaseFile()
if osName == "" {
osName = osInfo[3]
osName = kernelName
}
systemHostname, _ := os.Hostname()
@@ -58,8 +49,8 @@ func GetInfo(ctx context.Context) *Info {
}
gio := &Info{
Kernel: osInfo[0],
Platform: osInfo[2],
Kernel: kernelName,
Platform: kernelPlatform,
OS: osName,
OSVersion: osVersion,
Hostname: extractDeviceName(ctx, systemHostname),
@@ -67,7 +58,7 @@ func GetInfo(ctx context.Context) *Info {
CPUs: runtime.NumCPU(),
NetbirdVersion: version.NetbirdVersion(),
UIVersion: extractUserAgent(ctx),
KernelVersion: osInfo[1],
KernelVersion: kernelVersion,
NetworkAddresses: addrs,
SystemSerialNumber: si.SystemSerialNumber,
SystemProductName: si.SystemProductName,
@@ -78,18 +69,12 @@ func GetInfo(ctx context.Context) *Info {
return gio
}
func _getInfo() string {
cmd := exec.Command("uname", "-srio")
cmd.Stdin = strings.NewReader("some")
var out bytes.Buffer
var stderr bytes.Buffer
cmd.Stdout = &out
cmd.Stderr = &stderr
err := cmd.Run()
if err != nil {
log.Warnf("getInfo: %s", err)
func kernelInfo() (string, string, string) {
var uts unix.Utsname
if err := unix.Uname(&uts); err != nil {
return "", "", ""
}
return out.String()
return unix.ByteSliceToString(uts.Sysname[:]), unix.ByteSliceToString(uts.Release[:]), unix.ByteSliceToString(uts.Machine[:])
}
func sysInfo() (string, string, string) {

View File

@@ -6,6 +6,7 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"sync"
"syscall/js"
"time"
@@ -13,7 +14,7 @@ import (
)
const (
certValidationTimeout = 60 * time.Second
certValidationTimeout = 5 * time.Minute
)
func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) {
@@ -46,17 +47,31 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert
promise := conn.wsHandlers.Call("onCertificateRequest", certInfo)
resultChan := make(chan bool)
errorChan := make(chan error)
resultChan := make(chan bool, 1)
errorChan := make(chan error, 1)
promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
result := args[0].Bool()
resultChan <- result
// Release from inside the callbacks so a post-timeout promise resolution
// does not invoke an already-released func.
var thenFn, catchFn js.Func
var releaseOnce sync.Once
release := func() {
releaseOnce.Do(func() {
thenFn.Release()
catchFn.Release()
})
}
thenFn = js.FuncOf(func(this js.Value, args []js.Value) interface{} {
defer release()
resultChan <- args[0].Bool()
return nil
})).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
})
catchFn = js.FuncOf(func(this js.Value, args []js.Value) interface{} {
defer release()
errorChan <- fmt.Errorf("certificate validation failed")
return nil
}))
})
promise.Call("then", thenFn).Call("catch", catchFn)
select {
case result := <-resultChan:

View File

@@ -11,6 +11,7 @@ import (
"io"
"net"
"sync"
"sync/atomic"
"syscall/js"
"time"
@@ -57,6 +58,8 @@ type RDCleanPathProxy struct {
}
activeConnections map[string]*proxyConnection
destinations map[string]string
pendingHandlers map[string]js.Func
nextID atomic.Uint64
mu sync.Mutex
}
@@ -66,8 +69,15 @@ type proxyConnection struct {
rdpConn net.Conn
tlsConn *tls.Conn
wsHandlers js.Value
ctx context.Context
cancel context.CancelFunc
// Go-side callbacks exposed to JS. js.FuncOf pins the Go closure in a
// global handle map and MUST be released, otherwise every connection
// leaks the Go memory the closure captures.
wsHandlerFn js.Func
onMessageFn js.Func
onCloseFn js.Func
cleanupOnce sync.Once
ctx context.Context
cancel context.CancelFunc
}
// NewRDCleanPathProxy creates a new RDCleanPath proxy
@@ -80,7 +90,11 @@ func NewRDCleanPathProxy(client interface {
}
}
// CreateProxy creates a new proxy endpoint for the given destination
// CreateProxy creates a new proxy endpoint for the given destination.
// The registered handler fn and its destinations/pendingHandlers entries are
// only released once a connection is established and cleanupConnection runs.
// If a caller invokes CreateProxy but never connects to the returned URL,
// those entries stay pinned for the lifetime of the page.
func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
destination := net.JoinHostPort(hostname, port)
@@ -88,7 +102,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
resolve := args[0]
go func() {
proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections))
proxyID := fmt.Sprintf("proxy_%d", p.nextID.Add(1))
p.mu.Lock()
if p.destinations == nil {
@@ -100,7 +114,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID)
// Register the WebSocket handler for this specific proxy
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any {
handlerFn := js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("error: requires WebSocket argument")
}
@@ -108,7 +122,14 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
ws := args[0]
p.HandleWebSocketConnection(ws, proxyID)
return nil
}))
})
p.mu.Lock()
if p.pendingHandlers == nil {
p.pendingHandlers = make(map[string]js.Func)
}
p.pendingHandlers[proxyID] = handlerFn
p.mu.Unlock()
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), handlerFn)
log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination)
resolve.Invoke(proxyURL)
@@ -142,6 +163,10 @@ func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string
p.mu.Lock()
p.activeConnections[proxyID] = conn
if fn, ok := p.pendingHandlers[proxyID]; ok {
conn.wsHandlerFn = fn
delete(p.pendingHandlers, proxyID)
}
p.mu.Unlock()
p.setupWebSocketHandlers(ws, conn)
@@ -150,7 +175,7 @@ func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string
}
func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) {
ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any {
conn.onMessageFn = js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 1 {
return nil
}
@@ -158,13 +183,15 @@ func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnec
data := args[0]
go p.handleWebSocketMessage(conn, data)
return nil
}))
})
ws.Set("onGoMessage", conn.onMessageFn)
ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any {
conn.onCloseFn = js.FuncOf(func(_ js.Value, args []js.Value) any {
log.Debug("WebSocket closed by JavaScript")
conn.cancel()
return nil
}))
})
ws.Set("onGoClose", conn.onCloseFn)
}
func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) {
@@ -261,25 +288,49 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
}
func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) {
log.Debugf("Cleaning up connection %s", conn.id)
conn.cancel()
if conn.tlsConn != nil {
log.Debug("Closing TLS connection")
if err := conn.tlsConn.Close(); err != nil {
log.Debugf("Error closing TLS connection: %v", err)
conn.cleanupOnce.Do(func() {
log.Debugf("Cleaning up connection %s", conn.id)
conn.cancel()
if conn.tlsConn != nil {
log.Debug("Closing TLS connection")
if err := conn.tlsConn.Close(); err != nil {
log.Debugf("Error closing TLS connection: %v", err)
}
conn.tlsConn = nil
}
conn.tlsConn = nil
}
if conn.rdpConn != nil {
log.Debug("Closing TCP connection")
if err := conn.rdpConn.Close(); err != nil {
log.Debugf("Error closing TCP connection: %v", err)
if conn.rdpConn != nil {
log.Debug("Closing TCP connection")
if err := conn.rdpConn.Close(); err != nil {
log.Debugf("Error closing TCP connection: %v", err)
}
conn.rdpConn = nil
}
conn.rdpConn = nil
}
p.mu.Lock()
delete(p.activeConnections, conn.id)
p.mu.Unlock()
js.Global().Delete(fmt.Sprintf("handleRDCleanPathWebSocket_%s", conn.id))
// Detach before releasing so late JS calls surface as TypeError instead
// of silent "call to released function".
if conn.wsHandlers.Truthy() {
conn.wsHandlers.Set("onGoMessage", js.Undefined())
conn.wsHandlers.Set("onGoClose", js.Undefined())
}
// wsHandlerFn may be zero-value if the pending handler lookup missed.
if conn.wsHandlerFn.Truthy() {
conn.wsHandlerFn.Release()
}
if conn.onMessageFn.Truthy() {
conn.onMessageFn.Release()
}
if conn.onCloseFn.Truthy() {
conn.onCloseFn.Release()
}
p.mu.Lock()
delete(p.activeConnections, conn.id)
delete(p.destinations, conn.id)
delete(p.pendingHandlers, conn.id)
p.mu.Unlock()
})
}
func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {

View File

@@ -13,7 +13,7 @@ import (
func CreateJSInterface(client *Client) js.Value {
jsInterface := js.Global().Get("Object").Call("create", js.Null())
jsInterface.Set("write", js.FuncOf(func(this js.Value, args []js.Value) any {
writeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf(false)
}
@@ -32,9 +32,10 @@ func CreateJSInterface(client *Client) js.Value {
_, err := client.Write(bytes)
return js.ValueOf(err == nil)
}))
})
jsInterface.Set("write", writeFunc)
jsInterface.Set("resize", js.FuncOf(func(this js.Value, args []js.Value) any {
resizeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 2 {
return js.ValueOf(false)
}
@@ -42,14 +43,26 @@ func CreateJSInterface(client *Client) js.Value {
rows := args[1].Int()
err := client.Resize(cols, rows)
return js.ValueOf(err == nil)
}))
})
jsInterface.Set("resize", resizeFunc)
jsInterface.Set("close", js.FuncOf(func(this js.Value, args []js.Value) any {
closeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
client.Close()
return js.Undefined()
}))
})
jsInterface.Set("close", closeFunc)
go readLoop(client, jsInterface)
go func() {
readLoop(client, jsInterface)
// Detach before releasing so late JS calls surface as TypeError instead
// of silent "call to released function".
jsInterface.Set("write", js.Undefined())
jsInterface.Set("resize", js.Undefined())
jsInterface.Set("close", js.Undefined())
writeFunc.Release()
resizeFunc.Release()
closeFunc.Release()
}()
return jsInterface
}

View File

@@ -332,7 +332,7 @@ func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) {
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
}
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), s.IDPHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
if servers.relaySrv != nil {
log.Infof("Relay WebSocket handler added (path: /relay)")
}
@@ -521,7 +521,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
}
// createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
var relayAcceptFn func(conn listener.Conn)
@@ -556,6 +556,10 @@ func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, re
http.Error(w, "Relay service not enabled", http.StatusNotFound)
}
// Embedded IdP (Dex)
case idpHandler != nil && strings.HasPrefix(r.URL.Path, "/oauth2"):
idpHandler.ServeHTTP(w, r)
// Management HTTP API (default)
default:
httpHandler.ServeHTTP(w, r)

12
go.mod
View File

@@ -3,7 +3,7 @@ module github.com/netbirdio/netbird
go 1.25.5
require (
cunicu.li/go-rosenpass v0.4.0
cunicu.li/go-rosenpass v0.5.42
github.com/cenkalti/backoff/v4 v4.3.0
github.com/cloudflare/circl v1.3.3 // indirect
github.com/golang/protobuf v1.5.4
@@ -19,8 +19,8 @@ require (
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.50.0
golang.org/x/sys v0.43.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.80.0
google.golang.org/protobuf v1.36.11
@@ -38,7 +38,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3
github.com/c-robinson/iplib v1.0.3
github.com/caddyserver/certmagic v0.21.3
github.com/cilium/ebpf v0.15.0
github.com/cilium/ebpf v0.19.0
github.com/coder/websocket v1.8.14
github.com/coreos/go-iptables v0.7.0
github.com/coreos/go-oidc/v3 v3.18.0
@@ -60,7 +60,7 @@ require (
github.com/google/go-cmp v0.7.0
github.com/google/gopacket v1.1.19
github.com/google/nftables v0.3.0
github.com/gopacket/gopacket v1.1.1
github.com/gopacket/gopacket v1.4.0
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
@@ -335,7 +335,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f
replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6

26
go.sum
View File

@@ -7,8 +7,8 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdB
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:b8xUw3004wk+3ipBhu0VU4RtUJsegMIiqjxSK4++lzA=
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw=
cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M=
cunicu.li/go-rosenpass v0.5.42 h1:fRDsGwCxd7DhDgZI1Pxeo8GtNyq8BESZJ7w2/BGGJtU=
cunicu.li/go-rosenpass v0.5.42/go.mod h1:YRBeyKOe/gWpSX2kpDUec5p9t0XOLsshTguId5gTGVg=
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw=
@@ -111,8 +111,8 @@ github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1x
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
github.com/cilium/ebpf v0.19.0 h1:Ro/rE64RmFBeA9FGjcTc+KmCeY6jXmryu6FfnzPRIao=
github.com/cilium/ebpf v0.19.0/go.mod h1:fLCgMo3l8tZmAdM3B2XqdFzXBpwkcSTroaVqN08OWVY=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
@@ -225,8 +225,8 @@ github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3Bum
github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ=
github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM=
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI=
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6 h1:teYtXy9B7y5lHTp8V9KPxpYRAVA7dozigQcMiBust1s=
github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6/go.mod h1:p4lGIVX+8Wa6ZPNDvqcxq36XpUDLh42FLetFU7odllI=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
@@ -307,8 +307,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw=
github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs=
github.com/gopacket/gopacket v1.4.0 h1:cr1OlFpzksCkZHNO0eLjaSSOrMQnpPXg0j6qHIY3y2U=
github.com/gopacket/gopacket v1.4.0/go.mod h1:EpvsxINeehp5qj4YMKMLf2/dekdhKn2IIAO/ZOifS7o=
github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
@@ -390,6 +390,8 @@ github.com/jonboulle/clockwork v0.5.0 h1:Hyh9A8u51kptdkR+cqRpT1EebBwTn1oK9YfGYbd
github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/jsimonetti/rtnetlink/v2 v2.0.1 h1:xda7qaHDSVOsADNouv7ukSuicKZO7GgVUCXxpaIEIlM=
github.com/jsimonetti/rtnetlink/v2 v2.0.1/go.mod h1:7MoNYNbb3UaDHtF8udiJo/RH6VsTKP1pqKLUTVCvToE=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M=
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
@@ -499,8 +501,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ=
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f h1:ff2D57RBjWtyQ2wVwJOxOgXAXOe/J2lJWtSX0Bz/BRk=
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
@@ -900,8 +902,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=

View File

@@ -112,7 +112,7 @@ func (c *Controller) CountStreams() int {
return c.peersUpdateManager.CountStreams()
}
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
@@ -175,6 +175,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
continue
}
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountNmapTriggered(string(reason.Resource), string(reason.Operation))
}
wg.Add(1)
semaphore <- struct{}{}
go func(p *nbpeer.Peer) {
@@ -242,14 +246,14 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
go func() {
defer b.mu.Unlock()
_ = c.sendUpdateAccountPeers(ctx, accountID)
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
if !b.update.Load() {
return
}
b.update.Store(false)
if b.next == nil {
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
_ = c.sendUpdateAccountPeers(ctx, accountID)
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
})
return
}
@@ -265,7 +269,7 @@ func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, r
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
}
return c.sendUpdateAccountPeers(ctx, accountID)
return c.sendUpdateAccountPeers(ctx, accountID, reason)
}
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
@@ -359,14 +363,14 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
go func() {
defer b.mu.Unlock()
_ = c.sendUpdateAccountPeers(ctx, accountID)
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
if !b.update.Load() {
return
}
b.update.Store(false)
if b.next == nil {
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
_ = c.sendUpdateAccountPeers(ctx, accountID)
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
})
return
}

View File

@@ -51,7 +51,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
found = true
select {
case channel <- update:
log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID)
log.WithContext(ctx).Tracef("update was sent to channel for peer %s", peerID)
default:
dropped = true
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))

View File

@@ -5,6 +5,7 @@ package peers
import (
"context"
"fmt"
"net"
"time"
"github.com/rs/xid"
@@ -35,6 +36,14 @@ type Manager interface {
SetAccountManager(accountManager account.Manager)
GetPeerID(ctx context.Context, peerKey string) (string, error)
CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error
// GetPeerByTunnelIP looks up a peer in accountID by its WireGuard tunnel IP.
// Returns nil with an error when no match exists. No permission check;
// callers (the proxy's ValidateTunnelPeer RPC) are trusted server components.
GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error)
// GetPeerWithGroups returns the peer and the list of *types.Group it belongs
// to. Used by the proxy's auth path to authorise a request by the calling
// peer's group memberships.
GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error)
}
type managerImpl struct {
@@ -99,6 +108,26 @@ func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string,
return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
}
// GetPeerByTunnelIP delegates to the store's indexed lookup.
func (m *managerImpl) GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) {
return m.store.GetPeerByIP(ctx, store.LockingStrengthNone, accountID, ip)
}
// GetPeerWithGroups returns the peer plus its group memberships. Any store
// error returns (nil, nil, err) so callers never receive a valid peer
// alongside a non-nil error.
func (m *managerImpl) GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) {
p, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return nil, nil, err
}
groups, err := m.store.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return nil, nil, err
}
return p, groups, nil
}
func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {

View File

@@ -6,6 +6,7 @@ package peers
import (
context "context"
net "net"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
@@ -13,6 +14,7 @@ import (
account "github.com/netbirdio/netbird/management/server/account"
integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
peer "github.com/netbirdio/netbird/management/server/peer"
types "github.com/netbirdio/netbird/management/server/types"
)
// MockManager is a mock of Manager interface.
@@ -38,6 +40,20 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder {
return m.recorder
}
// CreateProxyPeer mocks base method.
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID, peerKey, cluster string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
ret0, _ := ret[0].(error)
return ret0
}
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
}
// DeletePeers mocks base method.
func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
m.ctrl.T.Helper()
@@ -97,6 +113,21 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID)
}
// GetPeerByTunnelIP mocks base method.
func (m *MockManager) GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeerByTunnelIP", ctx, accountID, ip)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPeerByTunnelIP indicates an expected call of GetPeerByTunnelIP.
func (mr *MockManagerMockRecorder) GetPeerByTunnelIP(ctx, accountID, ip interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerByTunnelIP", reflect.TypeOf((*MockManager)(nil).GetPeerByTunnelIP), ctx, accountID, ip)
}
// GetPeerID mocks base method.
func (m *MockManager) GetPeerID(ctx context.Context, peerKey string) (string, error) {
m.ctrl.T.Helper()
@@ -112,6 +143,22 @@ func (mr *MockManagerMockRecorder) GetPeerID(ctx, peerKey interface{}) *gomock.C
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerID", reflect.TypeOf((*MockManager)(nil).GetPeerID), ctx, peerKey)
}
// GetPeerWithGroups mocks base method.
func (m *MockManager) GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeerWithGroups", ctx, accountID, peerID)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].([]*types.Group)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// GetPeerWithGroups indicates an expected call of GetPeerWithGroups.
func (mr *MockManagerMockRecorder) GetPeerWithGroups(ctx, accountID, peerID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerWithGroups", reflect.TypeOf((*MockManager)(nil).GetPeerWithGroups), ctx, accountID, peerID)
}
// GetPeersByGroupIDs mocks base method.
func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
m.ctrl.T.Helper()
@@ -162,17 +209,3 @@ func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
}
// CreateProxyPeer mocks base method.
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
ret0, _ := ret[0].(error)
return ret0
}
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
}

View File

@@ -23,6 +23,8 @@ type Domain struct {
// SupportsCrowdSec is populated at query time from proxy cluster capabilities.
// Not persisted.
SupportsCrowdSec *bool `gorm:"-"`
// SupportsPrivate is populated at query time from proxy cluster capabilities. Not persisted.
SupportsPrivate *bool `gorm:"-"`
}
// EventMeta returns activity event metadata for a domain

View File

@@ -49,6 +49,7 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
SupportsCustomPorts: d.SupportsCustomPorts,
RequireSubdomain: d.RequireSubdomain,
SupportsCrowdsec: d.SupportsCrowdSec,
SupportsPrivate: d.SupportsPrivate,
}
if d.TargetCluster != "" {
resp.TargetCluster = &d.TargetCluster

View File

@@ -35,6 +35,7 @@ type proxyManager interface {
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
}
type Manager struct {
@@ -93,6 +94,7 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
d.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, cluster)
d.RequireSubdomain = m.proxyManager.ClusterRequireSubdomain(ctx, cluster)
d.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, cluster)
d.SupportsPrivate = m.proxyManager.ClusterSupportsPrivate(ctx, cluster)
ret = append(ret, d)
}
@@ -109,6 +111,7 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
if d.TargetCluster != "" {
cd.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, d.TargetCluster)
cd.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, d.TargetCluster)
cd.SupportsPrivate = m.proxyManager.ClusterSupportsPrivate(ctx, d.TargetCluster)
}
// Custom domains never require a subdomain by default since
// the account owns them and should be able to use the bare domain.

View File

@@ -10,7 +10,7 @@ import (
)
type mockProxyManager struct {
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
}
@@ -40,6 +40,10 @@ func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string)
return nil
}
func (m *mockProxyManager) ClusterSupportsPrivate(_ context.Context, _ string) *bool {
return nil
}
func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
@@ -151,4 +155,3 @@ func TestGetClusterAllowList_PublicEmpty_BYOPOnly(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, []string{"byop.example.com"}, result)
}

View File

@@ -19,6 +19,7 @@ type Manager interface {
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
CountAccountProxies(ctx context.Context, accountID string) (int64, error)

View File

@@ -21,6 +21,7 @@ type store interface {
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
@@ -137,6 +138,11 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string
return m.store.GetClusterSupportsCrowdSec(ctx, clusterAddr)
}
// ClusterSupportsPrivate reports whether any active proxy claims the private capability (nil = unreported).
func (m Manager) ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
return m.store.GetClusterSupportsPrivate(ctx, clusterAddr)
}
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
@@ -178,4 +184,3 @@ func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, acco
}
return nil
}

View File

@@ -15,16 +15,16 @@ import (
)
type mockStore struct {
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
}
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
@@ -99,6 +99,9 @@ func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *boo
func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
return nil
}
func (m *mockStore) GetClusterSupportsPrivate(_ context.Context, _ string) *bool {
return nil
}
func newTestManager(s store) *Manager {
meter := noop.NewMeterProvider().Meter("test")

View File

@@ -92,6 +92,20 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCrowdSec", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCrowdSec), ctx, clusterAddr)
}
// ClusterSupportsPrivate mocks base method.
func (m *MockManager) ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ClusterSupportsPrivate", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// ClusterSupportsPrivate indicates an expected call of ClusterSupportsPrivate.
func (mr *MockManagerMockRecorder) ClusterSupportsPrivate(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsPrivate", reflect.TypeOf((*MockManager)(nil).ClusterSupportsPrivate), ctx, clusterAddr)
}
// Connect mocks base method.
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) {
m.ctrl.T.Helper()

View File

@@ -20,6 +20,9 @@ type Capabilities struct {
RequireSubdomain *bool
// SupportsCrowdsec indicates whether this proxy has CrowdSec configured.
SupportsCrowdsec *bool
// Private indicates whether this proxy supports inbound access via Wireguard
// tunnel and netbird-only authentication policies
Private *bool
}
// Proxy represents a reverse proxy instance
@@ -67,10 +70,9 @@ type Cluster struct {
Type ClusterType
Online bool
ConnectedProxies int
// Capability flags. *bool because nil means "no proxy reported a
// capability for this cluster" — the dashboard renders these as
// unknown rather than false.
// *bool: nil = no proxy reported the capability; the dashboard renders that as unknown.
SupportsCustomPorts *bool
RequireSubdomain *bool
SupportsCrowdSec *bool
Private *bool
}

View File

@@ -204,6 +204,7 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
SupportsCustomPorts: c.SupportsCustomPorts,
RequireSubdomain: c.RequireSubdomain,
SupportsCrowdsec: c.SupportsCrowdSec,
Private: c.Private,
})
}

View File

@@ -82,6 +82,7 @@ type CapabilityProvider interface {
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
}
type Manager struct {
@@ -136,6 +137,7 @@ func (m *Manager) GetClusters(ctx context.Context, accountID, userID string) ([]
clusters[i].SupportsCustomPorts = m.capabilities.ClusterSupportsCustomPorts(ctx, clusters[i].Address)
clusters[i].RequireSubdomain = m.capabilities.ClusterRequireSubdomain(ctx, clusters[i].Address)
clusters[i].SupportsCrowdSec = m.capabilities.ClusterSupportsCrowdSec(ctx, clusters[i].Address)
clusters[i].Private = m.capabilities.ClusterSupportsPrivate(ctx, clusters[i].Address)
}
return clusters, nil
@@ -208,6 +210,9 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
target.Host = resource.Domain
case service.TargetTypeSubnet:
// For subnets we do not do any lookups on the resource
case service.TargetTypeCluster:
// Cluster targets carry the upstream address on target_id; the
// proxy resolves the destination at request time.
default:
return fmt.Errorf("unknown target type: %s", target.TargetType)
}
@@ -779,6 +784,10 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
if err := validateResourceTarget(ctx, transaction, accountID, target); err != nil {
return err
}
case service.TargetTypeCluster:
if err := validateClusterTarget(target); err != nil {
return err
}
default:
return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId)
}
@@ -786,6 +795,13 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
return nil
}
func validateClusterTarget(target *service.Target) error {
if !target.Options.DirectUpstream {
return status.Errorf(status.InvalidArgument, "cluster target %s has direct upstream disabled", target.Host)
}
return nil
}
func validatePeerTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error {
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
@@ -962,12 +978,14 @@ func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID str
return fmt.Errorf("failed to get services: %w", err)
}
oidcCfg := m.proxyController.GetOIDCValidationConfig()
for _, s := range services {
err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
}
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
}
return nil

View File

@@ -1344,3 +1344,66 @@ func TestValidateSubdomainRequirement(t *testing.T) {
})
}
}
func TestValidateTargetReferences_ClusterTargetSkipsLookup(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
accountID := "test-account"
// No peer or resource lookups must be issued for cluster targets.
targets := []*rpservice.Target{
{
TargetId: "eu.proxy.netbird.io",
TargetType: rpservice.TargetTypeCluster,
Options: rpservice.TargetOptions{DirectUpstream: true},
},
}
require.NoError(t, validateTargetReferences(ctx, mockStore, accountID, targets), "cluster target must validate without store lookups")
}
// TestValidateTargetReferences_ClusterTargetRequiresDirectUpstream pins the
// store-side check that cluster targets must opt into the host-stack dial
// path. Without DirectUpstream the proxy would route this target through
// the embedded NetBird client and fail on every request.
func TestValidateTargetReferences_ClusterTargetRequiresDirectUpstream(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
accountID := "test-account"
targets := []*rpservice.Target{
{
TargetId: "eu.proxy.netbird.io",
TargetType: rpservice.TargetTypeCluster,
Host: "backend.lan",
},
}
err := validateTargetReferences(ctx, mockStore, accountID, targets)
require.Error(t, err, "cluster target without direct_upstream must be rejected")
assert.ErrorContains(t, err, "direct upstream disabled")
}
func TestReplaceHostByLookup_SkipsClusterTarget(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
accountID := "test-account"
mgr := &Manager{store: mockStore}
svc := &rpservice.Service{
ID: "svc-1",
AccountID: accountID,
Targets: []*rpservice.Target{
{
TargetId: "eu.proxy.netbird.io",
TargetType: rpservice.TargetTypeCluster,
Host: "127.0.0.1",
},
},
}
require.NoError(t, mgr.replaceHostByLookup(ctx, accountID, svc), "cluster target must not trigger peer/resource lookup")
assert.Equal(t, "127.0.0.1", svc.Targets[0].Host, "operator-supplied host must be preserved for cluster target")
}

View File

@@ -45,10 +45,11 @@ const (
StatusCertificateFailed Status = "certificate_failed"
StatusError Status = "error"
TargetTypePeer TargetType = "peer"
TargetTypeHost TargetType = "host"
TargetTypeDomain TargetType = "domain"
TargetTypeSubnet TargetType = "subnet"
TargetTypePeer TargetType = "peer"
TargetTypeHost TargetType = "host"
TargetTypeDomain TargetType = "domain"
TargetTypeSubnet TargetType = "subnet"
TargetTypeCluster TargetType = "cluster"
SourcePermanent = "permanent"
SourceEphemeral = "ephemeral"
@@ -60,6 +61,11 @@ type TargetOptions struct {
SessionIdleTimeout time.Duration `json:"session_idle_timeout,omitempty"`
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
// DirectUpstream bypasses the proxy's embedded NetBird client and dials
// the target via the proxy host's network stack. Useful for upstreams
// reachable without WireGuard (public APIs, LAN services, localhost
// sidecars). Default false.
DirectUpstream bool `json:"direct_upstream,omitempty"`
}
type Target struct {
@@ -67,7 +73,7 @@ type Target struct {
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
Path *string `json:"path,omitempty"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Host string `json:"host"`
Port uint16 `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
@@ -200,6 +206,10 @@ type Service struct {
Mode string `gorm:"default:'http'"`
ListenPort uint16
PortAutoAssigned bool
// Private marks the service as NetBird-only: auth via ValidateTunnelPeer against AccessGroups instead of SSO. HTTP-only.
Private bool
// AccessGroups is the group ID allowlist for inbound peers on private services. Mutually exclusive with bearer SSO.
AccessGroups []string `json:"access_groups,omitempty" gorm:"serializer:json"`
}
// InitNewRecord generates a new unique ID and resets metadata for a newly created
@@ -299,6 +309,12 @@ func (s *Service) ToAPIResponse() *api.Service {
Mode: &mode,
ListenPort: &listenPort,
PortAutoAssigned: &s.PortAutoAssigned,
Private: &s.Private,
}
if len(s.AccessGroups) > 0 {
groups := append([]string(nil), s.AccessGroups...)
resp.AccessGroups = &groups
}
if s.ProxyCluster != "" {
@@ -308,6 +324,7 @@ func (s *Service) ToAPIResponse() *api.Service {
return resp
}
// ToProtoMapping converts the service into the wire format the proxy consumes.
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
pathMappings := s.buildPathMappings()
@@ -349,6 +366,7 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
RewriteRedirects: s.RewriteRedirects,
Mode: s.Mode,
ListenPort: int32(s.ListenPort), //nolint:gosec
Private: s.Private,
}
if r := restrictionsToProto(s.Restrictions); r != nil {
@@ -455,7 +473,8 @@ func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode {
}
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 {
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 &&
opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 && !opts.DirectUpstream {
return nil
}
apiOpts := &api.ServiceTargetOptions{}
@@ -477,17 +496,22 @@ func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
if len(opts.CustomHeaders) > 0 {
apiOpts.CustomHeaders = &opts.CustomHeaders
}
if opts.DirectUpstream {
apiOpts.DirectUpstream = &opts.DirectUpstream
}
return apiOpts
}
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 && len(opts.CustomHeaders) == 0 {
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 &&
len(opts.CustomHeaders) == 0 && !opts.DirectUpstream {
return nil
}
popts := &proto.PathTargetOptions{
SkipTlsVerify: opts.SkipTLSVerify,
PathRewrite: pathRewriteToProto(opts.PathRewrite),
CustomHeaders: opts.CustomHeaders,
SkipTlsVerify: opts.SkipTLSVerify,
PathRewrite: pathRewriteToProto(opts.PathRewrite),
CustomHeaders: opts.CustomHeaders,
DirectUpstream: opts.DirectUpstream,
}
if opts.RequestTimeout != 0 {
popts.RequestTimeout = durationpb.New(opts.RequestTimeout)
@@ -537,6 +561,9 @@ func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions,
if o.CustomHeaders != nil {
opts.CustomHeaders = *o.CustomHeaders
}
if o.DirectUpstream != nil {
opts.DirectUpstream = *o.DirectUpstream
}
return opts, nil
}
@@ -551,6 +578,14 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro
if req.ListenPort != nil {
s.ListenPort = uint16(*req.ListenPort) //nolint:gosec
}
if req.Private != nil {
s.Private = *req.Private
}
if req.AccessGroups != nil {
s.AccessGroups = append([]string(nil), *req.AccessGroups...)
} else {
s.AccessGroups = nil
}
targets, err := targetsFromAPI(accountID, req.Targets)
if err != nil {
@@ -740,6 +775,9 @@ func (s *Service) Validate() error {
if err := validateAccessRestrictions(&s.Restrictions); err != nil {
return err
}
if err := s.validatePrivateRequirements(); err != nil {
return err
}
switch s.Mode {
case ModeHTTP:
@@ -753,6 +791,23 @@ func (s *Service) Validate() error {
}
}
// validatePrivateRequirements enforces the private-service contract: HTTP mode, ≥1 access group, no bearer auth.
func (s *Service) validatePrivateRequirements() error {
if !s.Private {
return nil
}
if s.Mode != "" && s.Mode != ModeHTTP {
return fmt.Errorf("private services only support HTTP mode, got %q", s.Mode)
}
if len(s.AccessGroups) == 0 {
return errors.New("private services require at least one access group")
}
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
return errors.New("private services cannot enable bearer auth (SSO): NetBird-only access and SSO are mutually exclusive")
}
return nil
}
func (s *Service) validateHTTPMode() error {
if s.Domain == "" {
return errors.New("service domain is required")
@@ -799,11 +854,21 @@ func (s *Service) validateHTTPTargets() error {
for i, target := range s.Targets {
switch target.TargetType {
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
// host field will be ignored
// Host is normally overwritten by replaceHostByLookup with the
// resolved peer IP / resource address; operator-supplied values
// are honored only when DirectUpstream is set. Validate the
// override here so misconfigured hosts fail fast at API time.
if err := validateDirectUpstreamHost(i, target); err != nil {
return err
}
case TargetTypeSubnet:
if target.Host == "" {
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
}
case TargetTypeCluster:
if err := validateClusterTarget(i, target); err != nil {
return err
}
default:
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
}
@@ -821,25 +886,67 @@ func (s *Service) validateHTTPTargets() error {
return nil
}
// validateClusterTarget cluster targets should not have empty hosts and should have direct upstream enabled.
func validateClusterTarget(idx int, target *Target) error {
host := strings.TrimSpace(target.Host)
if host == "" {
return fmt.Errorf("target %d: has empty host", idx)
}
if !target.Options.DirectUpstream {
return fmt.Errorf("target %d: %s has direct upstream disabled", idx, target.Host)
}
return validateDirectUpstreamHost(idx, target)
}
// validateDirectUpstreamHost validates the operator-supplied Host on a
// peer/host/domain target when DirectUpstream is set. Empty Host is
// allowed — the lookup fills in the default peer IP / resource address.
// Without DirectUpstream the Host value is silently overwritten by
// replaceHostByLookup, so we don't validate it (preserves the historical
// behaviour where APIs accepted any value and dropped it). Non-empty
// Host with DirectUpstream must look like a hostname or IP and must
// not carry a port (port lives on Target.Port).
func validateDirectUpstreamHost(idx int, target *Target) error {
if !target.Options.DirectUpstream {
return nil
}
host := strings.TrimSpace(target.Host)
if host == "" {
return nil
}
if strings.ContainsAny(host, " \t/") {
return fmt.Errorf("target %d: host %q contains invalid characters", idx, host)
}
if _, _, err := net.SplitHostPort(host); err == nil {
return fmt.Errorf("target %d: host %q must not include a port (set target.port instead)", idx, host)
}
return nil
}
func (s *Service) validateL4Target(target *Target) error {
// L4 services have a single target; per-target disable is meaningless
// (use the service-level Enabled flag instead). Force it on so that
// buildPathMappings always includes the target in the proto.
target.Enabled = true
if target.Port == 0 {
return errors.New("target port is required for L4 services")
}
if target.TargetId == "" {
return errors.New("target_id is required for L4 services")
}
if target.TargetType != TargetTypeCluster && target.Port == 0 {
return errors.New("target port is required for L4 services")
}
switch target.TargetType {
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
// OK
if err := validateDirectUpstreamHost(0, target); err != nil {
return err
}
case TargetTypeSubnet:
if target.Host == "" {
return errors.New("target host is required for subnet targets")
}
case TargetTypeCluster:
// target_id carries the cluster address; the proxy resolves
// the upstream at request time.
default:
return fmt.Errorf("invalid target_type %q for L4 service", target.TargetType)
}
@@ -1174,6 +1281,11 @@ func (s *Service) Copy() *Service {
}
}
var accessGroups []string
if len(s.AccessGroups) > 0 {
accessGroups = append([]string(nil), s.AccessGroups...)
}
return &Service{
ID: s.ID,
AccountID: s.AccountID,
@@ -1195,6 +1307,8 @@ func (s *Service) Copy() *Service {
Mode: s.Mode,
ListenPort: s.ListenPort,
PortAutoAssigned: s.PortAutoAssigned,
Private: s.Private,
AccessGroups: accessGroups,
}
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -1116,3 +1117,191 @@ func TestValidate_HeaderAuths(t *testing.T) {
assert.Contains(t, err.Error(), "exceeds maximum length")
})
}
func TestValidate_HTTPClusterTarget(t *testing.T) {
rp := validProxy()
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
require.NoError(t, rp.Validate(), "HTTP cluster target with target_id, host, and direct_upstream must validate")
}
func TestValidate_HTTPClusterTarget_RequiresTargetId(t *testing.T) {
rp := validProxy()
rp.Targets = []*Target{{
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "empty target_id", "cluster target must reject empty target_id")
}
// TestValidate_HTTPClusterTarget_RequiresHost pins the new cluster-target
// rule that operator-supplied Host is mandatory: cluster targets dial the
// upstream via the host network stack (direct_upstream is implied), so an
// empty Host leaves the proxy with nothing to dial.
func TestValidate_HTTPClusterTarget_RequiresHost(t *testing.T) {
rp := validProxy()
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "empty host", "cluster target must reject empty host")
}
// TestValidate_HTTPClusterTarget_RequiresDirectUpstream pins the second
// half of the cluster-target rule: DirectUpstream must be true so the
// stdlib transport branch in MultiTransport is taken. Without it the
// embedded NetBird client would try to dial the cluster address through
// the WG tunnel, which is the wrong network for a cluster upstream.
func TestValidate_HTTPClusterTarget_RequiresDirectUpstream(t *testing.T) {
rp := validProxy()
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "direct upstream disabled", "cluster target must reject direct_upstream=false")
}
func TestValidate_L4ClusterTarget(t *testing.T) {
rp := validProxy()
rp.Mode = ModeTCP
rp.ListenPort = 9000
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "tcp",
Enabled: true,
}}
require.NoError(t, rp.Validate(), "L4 cluster target must validate without an explicit port")
}
func TestService_Copy_RoundtripsPrivate(t *testing.T) {
svc := validProxy()
svc.Private = true
svc.AccessGroups = []string{"grp-admins", "grp-ops"}
cp := svc.Copy()
require.NotNil(t, cp)
assert.True(t, cp.Private)
assert.Equal(t, []string{"grp-admins", "grp-ops"}, cp.AccessGroups)
cp.Private = false
assert.True(t, svc.Private)
cp.AccessGroups[0] = "grp-other"
assert.Equal(t, []string{"grp-admins", "grp-ops"}, svc.AccessGroups)
}
func TestService_APIRoundtrip_Private(t *testing.T) {
enabled := true
private := true
accessGroups := []string{"grp-admins"}
targets := []api.ServiceTarget{{
TargetId: "eu.proxy.netbird.io",
TargetType: api.ServiceTargetTargetType("cluster"),
Protocol: "http",
Port: 80,
Enabled: true,
}}
req := &api.ServiceRequest{
Name: "svc-private",
Domain: "myapp.eu.proxy.netbird.io",
Enabled: enabled,
Private: &private,
AccessGroups: &accessGroups,
Targets: &targets,
}
svc := &Service{}
require.NoError(t, svc.FromAPIRequest(req, "acc-1"))
assert.True(t, svc.Private)
assert.Equal(t, []string{"grp-admins"}, svc.AccessGroups)
resp := svc.ToAPIResponse()
require.NotNil(t, resp.Private)
assert.True(t, *resp.Private)
require.NotNil(t, resp.AccessGroups)
assert.Equal(t, []string{"grp-admins"}, *resp.AccessGroups)
}
func TestValidate_Private_RequiresAccessGroups(t *testing.T) {
rp := validProxy()
rp.Private = true
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "access group")
}
func TestValidate_Private_RejectsBearerAuth(t *testing.T) {
rp := validProxy()
rp.Private = true
rp.AccessGroups = []string{"grp-admins"}
rp.Auth.BearerAuth = &BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"grp-sso"},
}
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "mutually exclusive")
}
func TestValidate_Private_AcceptsNonClusterTargets(t *testing.T) {
rp := validProxy()
rp.Private = true
rp.AccessGroups = []string{"grp-admins"}
require.NoError(t, rp.Validate())
}
func TestValidate_Private_AcceptsClusterTargetWithAccessGroups(t *testing.T) {
rp := validProxy()
rp.Private = true
rp.AccessGroups = []string{"grp-admins"}
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
require.NoError(t, rp.Validate())
}
func TestValidate_Private_RejectsNonHTTPMode(t *testing.T) {
rp := validProxy()
rp.Private = true
rp.AccessGroups = []string{"grp-admins"}
rp.Mode = ModeTCP
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "tcp",
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "HTTP")
}

View File

@@ -20,6 +20,20 @@ type KeyPair struct {
type Claims struct {
jwt.RegisteredClaims
Method auth.Method `json:"method"`
// Email is the calling user's email address. Carried so the
// proxy can stamp identity on upstream requests (e.g.
// x-litellm-end-user-id) without an extra management
// round-trip on every cookie-bearing request.
Email string `json:"email,omitempty"`
// Groups carries the user's group IDs so the proxy can stamp them
// onto upstream requests (X-NetBird-Groups) from the cookie path
// without an extra management round-trip.
Groups []string `json:"groups,omitempty"`
// GroupNames carries the human-readable display names for the ids
// in Groups, ordered identically (positional pairing). Slice may be
// shorter than Groups for tokens minted before names were
// resolvable; the consumer falls back to ids for missing positions.
GroupNames []string `json:"group_names,omitempty"`
}
func GenerateKeyPair() (*KeyPair, error) {
@@ -34,7 +48,13 @@ func GenerateKeyPair() (*KeyPair, error) {
}, nil
}
func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration time.Duration) (string, error) {
// SignToken mints a session JWT for the given user and domain. email,
// groups, and groupNames, when non-empty, are embedded so the proxy can
// authorise and stamp identity for policy-aware middlewares without a
// management round-trip on every cookie-bearing request. groupNames
// pairs positionally with groups; pass nil when names couldn't be
// resolved.
func SignToken(privKeyB64, userID, email, domain string, method auth.Method, groups, groupNames []string, expiration time.Duration) (string, error) {
privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyB64)
if err != nil {
return "", fmt.Errorf("decode private key: %w", err)
@@ -56,7 +76,10 @@ func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
Method: method,
Method: method,
Email: email,
Groups: append([]string(nil), groups...),
GroupNames: append([]string(nil), groupNames...),
}
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)

View File

@@ -10,8 +10,10 @@ import (
"slices"
"time"
"github.com/gorilla/mux"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/rs/cors"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
@@ -19,7 +21,6 @@ import (
"google.golang.org/grpc/keepalive"
cachestore "github.com/eko/gocache/lib/v4/store"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook"
@@ -27,16 +28,20 @@ import (
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/activity"
activitystore "github.com/netbirdio/netbird/management/server/activity/store"
nbcache "github.com/netbirdio/netbird/management/server/cache"
nbContext "github.com/netbirdio/netbird/management/server/context"
nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util/crypt"
)
const apiPrefix = "/api"
var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
@@ -94,12 +99,17 @@ func (s *BaseServer) Store() store.Store {
func (s *BaseServer) EventStore() activity.Store {
return Create(s, func() activity.Store {
integrationMetrics, err := integrations.InitIntegrationMetrics(context.Background(), s.Metrics())
if err != nil {
log.Fatalf("failed to initialize integration metrics: %v", err)
var err error
key := s.Config.DataStoreEncryptionKey
if key == "" {
log.Debugf("generate new activity store encryption key")
key, err = crypt.GenerateKey()
if err != nil {
log.Fatalf("failed to generate event store encryption key: %v", err)
}
}
eventStore, _, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics)
eventStore, err := activitystore.NewSqlStore(context.Background(), s.Config.Datadir, key)
if err != nil {
log.Fatalf("failed to initialize event store: %v", err)
}
@@ -110,7 +120,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter())
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount)
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}
@@ -118,6 +128,22 @@ func (s *BaseServer) APIHandler() http.Handler {
})
}
// IDPHandler returns the HTTP handler for the embedded IdP (Dex), or nil if
// the deployment isn't using the embedded variant.
func (s *BaseServer) IDPHandler() http.Handler {
embeddedIdP, ok := s.IdpManager().(*idp.EmbeddedIdPManager)
if !ok || embeddedIdP == nil {
return nil
}
return cors.AllowAll().Handler(embeddedIdP.Handler())
}
func (s *BaseServer) Router() *mux.Router {
return Create(s, func() *mux.Router {
return mux.NewRouter().PathPrefix(apiPrefix).Subrouter()
})
}
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
return Create(s, func() *middleware.APIRateLimiter {
cfg, enabled := middleware.RateLimiterConfigFromEnv()

View File

@@ -19,6 +19,7 @@ import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
@@ -38,7 +39,7 @@ func (s *BaseServer) JobManager() *job.Manager {
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
return Create(s, func() integrated_validator.IntegratedValidator {
integratedPeerValidator, err := integrations.NewIntegratedValidator(
integratedPeerValidator, err := validator.NewIntegratedValidator(
context.Background(),
s.PeersManager(),
s.SettingsManager(),

View File

@@ -57,13 +57,7 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
func (s *BaseServer) PermissionsManager() permissions.Manager {
return Create(s, func() permissions.Manager {
manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter())
s.AfterInit(func(s *BaseServer) {
manager.SetAccountManager(s.AccountManager())
})
return manager
return permissions.NewManager(s.Store())
})
}
@@ -153,7 +147,6 @@ func (s *BaseServer) IdpManager() idp.Manager {
return idpManager
}
return nil
})
}
@@ -235,3 +228,7 @@ func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
return &m
})
}
func (s *BaseServer) IsValidChildAccount(_ context.Context, _, _, _ string) bool {
return false
}

View File

@@ -188,7 +188,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
}
rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter())
rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.IDPHandler(), s.Metrics().GetMeter())
switch {
case s.certManager != nil:
// a call to certManager.Listener() always creates a new listener so we do it once
@@ -299,7 +299,7 @@ func (s *BaseServer) SetHandlerFunc(handler http.Handler) {
log.Tracef("custom handler set successfully")
}
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, meter metric.Meter) http.Handler {
// Check if a custom handler was set (for multiplexing additional services)
if customHandler, ok := s.GetContainer("customHandler"); ok {
if handler, ok := customHandler.(http.Handler); ok {
@@ -318,6 +318,8 @@ func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, ht
gRPCHandler.ServeHTTP(writer, request)
case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
wsProxy.Handler().ServeHTTP(writer, request)
case idpHandler != nil && strings.HasPrefix(request.URL.Path, "/oauth2"):
idpHandler.ServeHTTP(writer, request)
default:
httpHandler.ServeHTTP(writer, request)
}

View File

@@ -6,9 +6,11 @@ import (
"net/netip"
"net/url"
"strings"
"time"
log "github.com/sirupsen/logrus"
goproto "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
@@ -185,9 +187,38 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
}
// settings == nil → field stays nil → "no info in this snapshot", client
// preserves the deadline it already had. settings non-nil → emit either a
// valid deadline or the explicit-zero "disabled" sentinel via
// encodeSessionExpiresAt.
if settings != nil {
response.SessionExpiresAt = encodeSessionExpiresAt(
peer.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration),
)
}
return response
}
// encodeSessionExpiresAt encodes a server-side deadline into the 3-state wire
// representation used on LoginResponse, SyncResponse and
// ExtendAuthSessionResponse. See the proto comments on those messages.
//
// - deadline.IsZero() → returns &Timestamp{} (seconds=0, nanos=0): the
// "expiry disabled or peer is not SSO-tracked" sentinel; the client clears
// its anchor.
// - deadline non-zero → returns timestamppb.New(deadline): the new absolute
// UTC deadline.
//
// Returning nil ("no info, preserve client's anchor") is the caller's job —
// only meaningful on Sync builds where settings were not resolved.
func encodeSessionExpiresAt(deadline time.Time) *timestamppb.Timestamp {
if deadline.IsZero() {
return &timestamppb.Timestamp{}
}
return timestamppb.New(deadline)
}
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
userIDToIndex := make(map[string]uint32)
var hashedUsers [][]byte

View File

@@ -5,6 +5,7 @@ import (
"net/netip"
"reflect"
"testing"
"time"
"github.com/stretchr/testify/assert"
@@ -200,3 +201,29 @@ func TestBuildJWTConfig_Audiences(t *testing.T) {
})
}
}
// TestEncodeSessionExpiresAt pins the wire encoding the client's
// applySessionDeadline depends on:
//
// - zero deadline → &Timestamp{} (seconds=0, nanos=0): the explicit
// "expiry disabled or peer is not SSO-tracked" sentinel.
// - non-zero → timestamppb.New(deadline): the absolute UTC deadline.
//
// The third state (nil pointer = "no info in this snapshot") is the caller's
// responsibility on the Sync path when settings could not be resolved; the
// helper itself never returns nil.
func TestEncodeSessionExpiresAt(t *testing.T) {
t.Run("zero deadline encodes as explicit-zero sentinel", func(t *testing.T) {
got := encodeSessionExpiresAt(time.Time{})
assert.NotNil(t, got, "must not return nil; nil means 'no info', not 'disabled'")
assert.Equal(t, int64(0), got.GetSeconds())
assert.Equal(t, int32(0), got.GetNanos())
})
t.Run("non-zero deadline round-trips", func(t *testing.T) {
deadline := time.Date(2030, 1, 2, 3, 4, 5, 0, time.UTC)
got := encodeSessionExpiresAt(deadline)
assert.NotNil(t, got)
assert.True(t, got.AsTime().Equal(deadline))
})
}

View File

@@ -351,6 +351,7 @@ func (s *ProxyServiceServer) registerProxyConnection(ctx context.Context, params
SupportsCustomPorts: c.SupportsCustomPorts,
RequireSubdomain: c.RequireSubdomain,
SupportsCrowdsec: c.SupportsCrowdsec,
Private: c.Private,
}
}
@@ -754,6 +755,11 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
InitialSyncComplete: update.InitialSyncComplete,
}
}
// Drop mappings the proxy lacks capability for (e.g. private without SupportsPrivateService).
connUpdate = filterMappingsForProxy(conn, connUpdate)
if connUpdate == nil || len(connUpdate.Mapping) == 0 {
return true
}
resp := s.perProxyMessage(connUpdate, conn.proxyID)
if resp == nil {
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
@@ -882,16 +888,20 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
}
}
// proxyAcceptsMapping returns whether the proxy should receive this mapping.
// Old proxies that never reported capabilities are skipped for non-TLS L4
// mappings with a custom listen port, since they don't understand the
// protocol. Proxies that report capabilities (even SupportsCustomPorts=false)
// are new enough to handle the mapping. TLS uses SNI routing and works on
// any proxy. Delete operations are always sent so proxies can clean up.
// proxyAcceptsMapping returns whether the proxy can receive this mapping.
// Private mappings require SupportsPrivateService; custom-port L4 mappings
// require SupportsCustomPorts. Remove operations always pass so proxies can
// clean up.
func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool {
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
return true
}
if mapping.GetPrivate() {
caps := conn.capabilities
if caps == nil || caps.SupportsPrivateService == nil || !*caps.SupportsPrivateService {
return false
}
}
if mapping.ListenPort == 0 || mapping.Mode == "tls" {
return true
}
@@ -900,6 +910,29 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo
return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil
}
// filterMappingsForProxy drops mappings the proxy cannot safely receive
// (e.g. private mappings to a proxy without SupportsPrivateService).
// Returns the input unchanged when no filtering is needed.
func filterMappingsForProxy(conn *proxyConnection, update *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse {
if update == nil || len(update.Mapping) == 0 {
return update
}
kept := make([]*proto.ProxyMapping, 0, len(update.Mapping))
for _, m := range update.Mapping {
if !proxyAcceptsMapping(conn, m) {
continue
}
kept = append(kept, m)
}
if len(kept) == len(update.Mapping) {
return update
}
return &proto.GetMappingUpdateResponse{
Mapping: kept,
InitialSyncComplete: update.InitialSyncComplete,
}
}
// perProxyMessage returns a copy of update with a fresh one-time token for
// create/update operations. For delete operations the original mapping is
// used unchanged because proxies do not need to authenticate for removal.
@@ -961,7 +994,10 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
authenticated, userId, method := s.authenticateRequest(ctx, req, service)
token, err := s.generateSessionToken(ctx, authenticated, service, userId, method)
// Non-OIDC schemes (PIN/Password/Header) authenticate against per-service
// secrets and have no user-level group context, so groups stay nil. Email
// is also empty — these schemes don't resolve a user record at sign time.
token, err := s.generateSessionToken(ctx, authenticated, service, userId, "", method, nil, nil)
if err != nil {
return nil, err
}
@@ -1050,7 +1086,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err
}
}
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) {
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId, userEmail string, method proxyauth.Method, groupIDs, groupNames []string) (string, error) {
if !authenticated || service.SessionPrivateKey == "" {
return "", nil
}
@@ -1058,8 +1094,11 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
token, err := sessionkey.SignToken(
service.SessionPrivateKey,
userId,
userEmail,
service.Domain,
method,
groupIDs,
groupNames,
proxyauth.DefaultSessionExpiry,
)
if err != nil {
@@ -1070,6 +1109,26 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
return token, nil
}
// pairGroupIDsAndNames splits a slice of resolved *types.Group records
// into parallel id and name slices. ids[i] and names[i] always pair to
// the same group. nil entries (orphan ids the manager couldn't resolve)
// are skipped so the consumer can rely on positional pairing.
func pairGroupIDsAndNames(groups []*types.Group) (ids, names []string) {
if len(groups) == 0 {
return nil, nil
}
ids = make([]string, 0, len(groups))
names = make([]string, 0, len(groups))
for _, g := range groups {
if g == nil {
continue
}
ids = append(ids, g.ID)
names = append(names, g.Name)
}
return ids, names
}
// SendStatusUpdate handles status updates from proxy clients.
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
@@ -1334,7 +1393,9 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
return verifier, redirectURL, nil
}
// GenerateSessionToken creates a signed session JWT for the given domain and user.
// GenerateSessionToken creates a signed session JWT for the given domain and
// user. The user's group memberships are embedded in the token so policy-aware
// middlewares on the proxy can authorise without an extra management round-trip.
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
service, err := s.getServiceByDomain(ctx, domain)
if err != nil {
@@ -1345,11 +1406,29 @@ func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, u
return "", fmt.Errorf("no session key configured for domain: %s", domain)
}
var (
email string
groupIDs []string
groupNames []string
)
if s.usersManager != nil {
user, userGroups, uerr := s.usersManager.GetUserWithGroups(ctx, userID)
if uerr != nil {
log.WithContext(ctx).Debugf("session token mint: lookup user %s: %v", userID, uerr)
} else if user != nil {
email = user.Email
groupIDs, groupNames = pairGroupIDsAndNames(userGroups)
}
}
return sessionkey.SignToken(
service.SessionPrivateKey,
userID,
email,
domain,
method,
groupIDs,
groupNames,
proxyauth.DefaultSessionExpiry,
)
}
@@ -1453,7 +1532,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil
}
userID, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes)
userID, _, _, _, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
@@ -1466,7 +1545,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil
}
user, err := s.usersManager.GetUser(ctx, userID)
user, userGroups, err := s.usersManager.GetUserWithGroups(ctx, userID)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
@@ -1500,12 +1579,15 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
"user_id": userID,
"error": err.Error(),
}).Debug("ValidateSession: access denied")
groupIDs, groupNames := pairGroupIDsAndNames(userGroups)
//nolint:nilerr
return &proto.ValidateSessionResponse{
Valid: false,
UserId: user.Id,
UserEmail: user.Email,
DeniedReason: "not_in_group",
Valid: false,
UserId: user.Id,
UserEmail: user.Email,
DeniedReason: "not_in_group",
PeerGroupIds: groupIDs,
PeerGroupNames: groupNames,
}, nil
}
@@ -1515,10 +1597,13 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
"email": user.Email,
}).Debug("ValidateSession: access granted")
groupIDs, groupNames := pairGroupIDsAndNames(userGroups)
return &proto.ValidateSessionResponse{
Valid: true,
UserId: user.Id,
UserEmail: user.Email,
Valid: true,
UserId: user.Id,
UserEmail: user.Email,
PeerGroupIds: groupIDs,
PeerGroupNames: groupNames,
}, nil
}
@@ -1551,3 +1636,154 @@ func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *
}
func ptr[T any](v T) *T { return &v }
// ValidateTunnelPeer resolves an inbound peer by its WireGuard tunnel IP and
// checks the peer's group membership against the service's access groups.
// Peers without a user (machine agents, automation workloads) are first-class
// callers; authorisation runs off peer-group memberships rather than the
// optional owning user's auto-groups. On success a session JWT is minted so
// the proxy can install a cookie and skip subsequent management round-trips.
func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
domain := req.GetDomain()
tunnelIPStr := req.GetTunnelIp()
if domain == "" || tunnelIPStr == "" {
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "missing domain or tunnel_ip",
}, nil
}
tunnelIP := net.ParseIP(tunnelIPStr)
if tunnelIP == nil {
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "invalid_tunnel_ip",
}, nil
}
service, err := s.getServiceByDomain(ctx, domain)
if err != nil {
log.WithFields(log.Fields{"domain": domain, "error": err.Error()}).Debug("ValidateTunnelPeer: service not found")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "service_not_found",
}, nil
}
// Mirror ValidateSession: account-scoped (BYOP) proxy tokens may only
// validate and mint session cookies for their own account's domains.
if err := enforceAccountScope(ctx, service.AccountID); err != nil {
return nil, err
}
peer, err := s.peersManager.GetPeerByTunnelIP(ctx, service.AccountID, tunnelIP)
if err != nil || peer == nil {
log.WithFields(log.Fields{"domain": domain, "tunnel_ip": tunnelIPStr}).Debug("ValidateTunnelPeer: peer not found")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "peer_not_found",
}, nil
}
_, peerGroups, err := s.peersManager.GetPeerWithGroups(ctx, service.AccountID, peer.ID)
if err != nil {
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: peer groups lookup failed")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "peer_not_found",
}, nil
}
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
// Resolve the principal: when the peer is linked to a user, the human
// is the principal so multiple peers owned by the same user share a
// single identity. Unlinked peers (machine agents) are their own
// principal keyed on peer.ID. displayIdentity is what upstream gateways
// tag spend with — user.Email when linked, peer.Name when not.
principalID := peer.ID
displayIdentity := peer.Name
if peer.UserID != "" {
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
principalID = user.Id
if user.Email != "" {
displayIdentity = user.Email
}
}
}
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
UserId: principalID,
UserEmail: displayIdentity,
DeniedReason: "not_in_group",
PeerGroupIds: groupIDs,
PeerGroupNames: groupNames,
}, nil
}
token, err := s.generateSessionToken(ctx, true, service, principalID, displayIdentity, proxyauth.MethodOIDC, groupIDs, groupNames)
if err != nil {
return nil, err
}
log.WithFields(log.Fields{
"domain": domain,
"tunnel_ip": tunnelIPStr,
"peer_id": peer.ID,
"principal_id": principalID,
}).Debug("ValidateTunnelPeer: access granted")
return &proto.ValidateTunnelPeerResponse{
Valid: true,
UserId: principalID,
UserEmail: displayIdentity,
SessionToken: token,
PeerGroupIds: groupIDs,
PeerGroupNames: groupNames,
}, nil
}
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
// groups. Private services authorise against AccessGroups (empty list fails
// closed — Validate() rejects that at save time but the RPC is the security
// boundary and must not trust upstream state). Bearer-auth services authorise
// against DistributionGroups when populated. Non-private non-bearer services
// are open.
func checkPeerGroupAccess(service *rpservice.Service, peerGroupIDs []string) error {
if service.Private {
if len(service.AccessGroups) == 0 {
return fmt.Errorf("private service has no access groups")
}
return matchAnyGroup(service.AccessGroups, peerGroupIDs)
}
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled && len(service.Auth.BearerAuth.DistributionGroups) > 0 {
return matchAnyGroup(service.Auth.BearerAuth.DistributionGroups, peerGroupIDs)
}
return nil
}
// matchAnyGroup returns nil when peerGroupIDs intersects allowedGroups,
// else a non-nil error.
func matchAnyGroup(allowedGroups, peerGroupIDs []string) error {
if len(allowedGroups) == 0 {
return fmt.Errorf("no allowed groups configured")
}
allowed := make(map[string]struct{}, len(allowedGroups))
for _, g := range allowedGroups {
allowed[g] = struct{}{}
}
for _, g := range peerGroupIDs {
if _, ok := allowed[g]; ok {
return nil
}
}
return fmt.Errorf("peer not in allowed groups")
}

View File

@@ -129,6 +129,14 @@ func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.U
return user, nil
}
func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) {
user, err := m.GetUser(ctx, userID)
if err != nil {
return nil, nil, err
}
return user, nil, nil
}
func TestValidateUserGroupAccess(t *testing.T) {
tests := []struct {
name string
@@ -420,3 +428,46 @@ func TestGetAccountProxyByDomain(t *testing.T) {
})
}
}
func TestCheckPeerGroupAccess(t *testing.T) {
t.Run("private with empty AccessGroups denies", func(t *testing.T) {
svc := &service.Service{Private: true, AccessGroups: nil}
err := checkPeerGroupAccess(svc, []string{"grp-admins"})
require.Error(t, err)
assert.Contains(t, err.Error(), "no access groups")
})
t.Run("private with peer in AccessGroups allows", func(t *testing.T) {
svc := &service.Service{Private: true, AccessGroups: []string{"grp-admins", "grp-ops"}}
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-other", "grp-ops"}))
})
t.Run("private with peer outside AccessGroups denies", func(t *testing.T) {
svc := &service.Service{Private: true, AccessGroups: []string{"grp-admins"}}
assert.Error(t, checkPeerGroupAccess(svc, []string{"grp-other"}))
})
t.Run("bearer enabled with empty DistributionGroups allows", func(t *testing.T) {
svc := &service.Service{
Auth: service.AuthConfig{BearerAuth: &service.BearerAuthConfig{Enabled: true}},
}
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-anyone"}))
})
t.Run("bearer enabled gates on DistributionGroups", func(t *testing.T) {
svc := &service.Service{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"grp-allowed"},
},
},
}
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-allowed"}))
assert.Error(t, checkPeerGroupAccess(svc, []string{"grp-other"}))
})
t.Run("non-private non-bearer is open", func(t *testing.T) {
assert.NoError(t, checkPeerGroupAccess(&service.Service{}, nil))
})
}

View File

@@ -437,7 +437,7 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
return nil
}
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
log.WithContext(ctx).Tracef("received an update for peer %s", peerKey.String())
if debouncer.ProcessUpdate(update) {
// Send immediately (first update or after quiet period)
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
@@ -492,7 +492,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
return status.Errorf(codes.Internal, "failed sending update message")
}
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
log.WithContext(ctx).Tracef("sent an update to peer %s", peerKey.String())
return nil
}
@@ -821,6 +821,80 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
}, nil
}
// ExtendAuthSession refreshes the peer's SSO session expiry deadline using a
// fresh JWT. The same JWT validation pipeline as Login is used. The tunnel
// stays up; no network map sync is performed. The new deadline is returned
// in ExtendAuthSessionResponse.SessionExpiresAt.
func (s *Server) ExtendAuthSession(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
extendReq := &proto.ExtendAuthSessionRequest{}
peerKey, err := s.parseRequest(ctx, req, extendReq)
if err != nil {
return nil, err
}
//nolint
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
if accountID, accErr := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()); accErr == nil {
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
}
jwt := extendReq.GetJwtToken()
if jwt == "" {
return nil, status.Errorf(codes.InvalidArgument, "jwt token is required")
}
var userID string
const attempts = 3
for i := 0; i < attempts; i++ {
userID, err = s.validateToken(ctx, peerKey.String(), jwt)
if err == nil {
break
}
if i == attempts-1 {
break
}
log.WithContext(ctx).Warnf("failed validating JWT token while extending session for peer %s: %v. Retrying (idP cache).", peerKey.String(), err)
select {
case <-time.After(200 * time.Millisecond):
case <-ctx.Done():
return nil, ctx.Err()
}
}
if err != nil {
return nil, err
}
if userID == "" {
return nil, status.Errorf(codes.Unauthenticated, "jwt token did not yield a user id")
}
deadline, err := s.accountManager.ExtendPeerSession(ctx, peerKey.String(), userID)
if err != nil {
log.WithContext(ctx).Warnf("failed extending session for peer %s: %v", peerKey.String(), err)
return nil, mapError(ctx, err)
}
// Success path normally returns a non-zero deadline. A defensive zero
// would still encode as the explicit "disabled" sentinel rather than nil,
// so the client clears any stale anchor instead of preserving it.
resp := &proto.ExtendAuthSessionResponse{
SessionExpiresAt: encodeSessionExpiresAt(deadline),
}
wgKey, err := s.secretsManager.GetWGKey()
if err != nil {
return nil, status.Errorf(codes.Internal, "failed processing request")
}
encrypted, err := encryption.EncryptMessage(peerKey, wgKey, resp)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed encrypting response")
}
return &proto.EncryptedMessage{
WgPubKey: wgKey.PublicKey().String(),
Body: encrypted,
}, nil
}
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
var relayToken *Token
var err error
@@ -844,6 +918,12 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
Checks: toProtocolChecks(ctx, postureChecks),
}
// settings is always non-nil here, so we never emit nil — encoder returns
// either a valid deadline or the explicit-zero "disabled" sentinel.
loginResp.SessionExpiresAt = encodeSessionExpiresAt(
peer.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration),
)
return loginResp, nil
}

View File

@@ -102,7 +102,7 @@ func generateSessionKeyPair(t *testing.T) (string, string) {
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
t.Helper()
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, time.Hour)
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, nil, time.Hour)
require.NoError(t, err)
return token
}
@@ -125,6 +125,7 @@ func TestValidateSession_UserAllowed(t *testing.T) {
assert.True(t, resp.Valid, "User should be allowed access")
assert.Equal(t, "allowedUserId", resp.UserId)
assert.Empty(t, resp.DeniedReason)
assert.Equal(t, []string{"allowedGroupId"}, resp.GetPeerGroupIds(), "PeerGroupIds must mirror the resolved user's group memberships")
}
func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
@@ -145,6 +146,7 @@ func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
assert.False(t, resp.Valid, "User not in group should be denied")
assert.Equal(t, "not_in_group", resp.DeniedReason)
assert.Equal(t, "nonGroupUserId", resp.UserId)
assert.Empty(t, resp.GetPeerGroupIds(), "PeerGroupIds must mirror the resolved user's actual (empty) memberships on denial")
}
func TestValidateSession_UserInDifferentAccount(t *testing.T) {

View File

@@ -355,7 +355,17 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
oldSettings.DNSDomain != newSettings.DNSDomain ||
oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion ||
oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways {
oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways ||
oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled ||
oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
// Session deadline is derived from LastLogin + PeerLoginExpiration
// on every Login/Sync response. Without a fan-out push, connected
// peers keep the deadline they received at login time and only see
// the new value after the next unrelated NetworkMap change. Add
// these two fields to the trigger list so admin-side expiry tweaks
// (e.g. shortening from 24h to 1h) reach every connected peer
// within seconds, which is what the proactive-warning feature
// relies on (see client/internal/auth/sessionwatch).
updateAccountPeers = true
}

View File

@@ -109,6 +109,7 @@ type Manager interface {
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)

View File

@@ -1304,6 +1304,21 @@ func (mr *MockManagerMockRecorder) LoginPeer(ctx, login interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginPeer", reflect.TypeOf((*MockManager)(nil).LoginPeer), ctx, login)
}
// ExtendPeerSession mocks base method.
func (m *MockManager) ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ExtendPeerSession", ctx, peerPubKey, userID)
ret0, _ := ret[0].(time.Time)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ExtendPeerSession indicates an expected call of ExtendPeerSession.
func (mr *MockManagerMockRecorder) ExtendPeerSession(ctx, peerPubKey, userID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExtendPeerSession", reflect.TypeOf((*MockManager)(nil).ExtendPeerSession), ctx, peerPubKey, userID)
}
// MarkPeerConnected mocks base method.
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
m.ctrl.T.Helper()

View File

@@ -240,6 +240,10 @@ const (
AccountLocalMfaEnabled Activity = 123
// AccountLocalMfaDisabled indicates that a user disabled TOTP MFA for local users
AccountLocalMfaDisabled Activity = 124
// UserExtendedPeerSession indicates that a user refreshed their peer's
// SSO session deadline via ExtendAuthSession without re-establishing the
// tunnel. Distinct from UserLoggedInPeer (full interactive login).
UserExtendedPeerSession Activity = 125
AccountDeleted Activity = 99999
)
@@ -394,6 +398,8 @@ var activityMap = map[Activity]Code{
AccountLocalMfaEnabled: {"Account local MFA enabled", "account.setting.local.mfa.enable"},
AccountLocalMfaDisabled: {"Account local MFA disabled", "account.setting.local.mfa.disable"},
UserExtendedPeerSession: {"User extended peer session", "user.peer.session.extend"},
DomainAdded: {"Domain added", "domain.add"},
DomainDeleted: {"Domain deleted", "domain.delete"},
DomainValidated: {"Domain validated", "domain.validate"},

View File

@@ -15,15 +15,13 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
idpmanager "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/modules/zones"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
@@ -32,12 +30,10 @@ import (
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/geolocation"
nbgroups "github.com/netbirdio/netbird/management/server/groups"
@@ -56,17 +52,14 @@ import (
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
nbinstance "github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/telemetry"
)
const apiPrefix = "/api"
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter, isValidChildAccount middleware.IsValidChildAccountFunc) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -100,25 +93,16 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
accountManager.GetUserFromUserAuth,
rateLimiter,
appMetrics.GetMeter(),
isValidChildAccount,
)
corsMiddleware := cors.AllowAll()
rootRouter := mux.NewRouter()
metricsMiddleware := appMetrics.HTTPMiddleware()
prefix := apiPrefix
router := rootRouter.PathPrefix(prefix).Subrouter()
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler)
if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, integratedValidator, appMetrics.GetMeter(), permissionsManager, peersManager, proxyController, settingsManager); err != nil {
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}
// Check if embedded IdP is enabled for instance manager
embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager)
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP)
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), idpManager)
if err != nil {
return nil, fmt.Errorf("failed to create instance manager: %w", err)
}
@@ -154,10 +138,5 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
oauthHandler.RegisterEndpoints(router)
}
// Mount embedded IdP handler at /oauth2 path if configured
if embeddedIdpEnabled {
rootRouter.PathPrefix("/oauth2").Handler(corsMiddleware.Handler(embeddedIdP.Handler()))
}
return rootRouter, nil
return router, nil
}

View File

@@ -11,8 +11,6 @@ import (
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/management-integrations/integrations"
serverauth "github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
@@ -27,6 +25,8 @@ type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) err
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
type IsValidChildAccountFunc func(ctx context.Context, userID, accountID, childAccountID string) bool
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
authManager serverauth.Manager
@@ -35,6 +35,7 @@ type AuthMiddleware struct {
syncUserJWTGroups SyncUserJWTGroupsFunc
rateLimiter *APIRateLimiter
patUsageTracker *PATUsageTracker
isValidChildAccount IsValidChildAccountFunc
}
// NewAuthMiddleware instance constructor
@@ -45,6 +46,7 @@ func NewAuthMiddleware(
getUserFromUserAuth GetUserFromUserAuthFunc,
rateLimiter *APIRateLimiter,
meter metric.Meter,
isValidChildAccount IsValidChildAccountFunc,
) *AuthMiddleware {
var patUsageTracker *PATUsageTracker
if meter != nil {
@@ -62,6 +64,7 @@ func NewAuthMiddleware(
getUserFromUserAuth: getUserFromUserAuth,
rateLimiter: rateLimiter,
patUsageTracker: patUsageTracker,
isValidChildAccount: isValidChildAccount,
}
}
@@ -124,7 +127,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
}
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
if integrations.IsValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) {
if m.isValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) {
userAuth.AccountId = impersonate[0]
userAuth.IsChild = true
}
@@ -203,7 +206,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
}
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
if integrations.IsValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) {
if m.isValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) {
userAuth.AccountId = impersonate[0]
userAuth.IsChild = true
}

View File

@@ -211,6 +211,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
},
disabledLimiter,
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handlerToTest := authMiddleware.Handler(nextHandler)
@@ -270,6 +271,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
},
NewAPIRateLimiter(rateLimitConfig),
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -322,6 +324,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
},
NewAPIRateLimiter(rateLimitConfig),
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -365,6 +368,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
},
NewAPIRateLimiter(rateLimitConfig),
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -409,6 +413,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
},
NewAPIRateLimiter(rateLimitConfig),
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -473,6 +478,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
},
NewAPIRateLimiter(rateLimitConfig),
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -532,6 +538,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
},
NewAPIRateLimiter(rateLimitConfig),
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -587,6 +594,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
},
NewAPIRateLimiter(rateLimitConfig),
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -687,6 +695,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
},
disabledLimiter,
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
for _, tc := range tt {

View File

@@ -7,6 +7,7 @@ import (
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"go.opentelemetry.io/otel/metric/noop"
@@ -135,7 +136,8 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}
@@ -264,7 +266,8 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}

View File

@@ -0,0 +1,62 @@
package validator
import (
"context"
cachestore "github.com/eko/gocache/lib/v4/store"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
type IntegratedValidatorImpl struct{}
func NewIntegratedValidator(_ context.Context, _ peers.Manager, _ settings.Manager, _ activity.Store, _ cachestore.StoreInterface) (*IntegratedValidatorImpl, error) {
return &IntegratedValidatorImpl{}, nil
}
func (v *IntegratedValidatorImpl) ValidateExtraSettings(context.Context, *types.ExtraSettings, *types.ExtraSettings, string, string) error {
return nil
}
func (v *IntegratedValidatorImpl) ValidatePeer(_ context.Context, update *nbpeer.Peer, _ *nbpeer.Peer, _ string, _ string, _ string, _ []string, _ *types.ExtraSettings) (*nbpeer.Peer, bool, error) {
return update, false, nil
}
func (v *IntegratedValidatorImpl) PreparePeer(_ context.Context, _ string, peer *nbpeer.Peer, _ []string, _ *types.ExtraSettings, _ bool) *nbpeer.Peer {
return peer.Copy()
}
func (v *IntegratedValidatorImpl) IsNotValidPeer(_ context.Context, _ string, _ *nbpeer.Peer, _ []string, _ *types.ExtraSettings) (bool, bool, error) {
return false, false, nil
}
func (v *IntegratedValidatorImpl) GetValidatedPeers(_ context.Context, _ string, _ []*types.Group, peers []*nbpeer.Peer, _ *types.ExtraSettings) (map[string]struct{}, error) {
validatedPeers := make(map[string]struct{})
for _, p := range peers {
validatedPeers[p.ID] = struct{}{}
}
return validatedPeers, nil
}
func (v *IntegratedValidatorImpl) GetInvalidPeers(_ context.Context, _ string, _ *types.ExtraSettings) (map[string]string, error) {
return make(map[string]string), nil
}
func (v *IntegratedValidatorImpl) PeerDeleted(_ context.Context, _, _ string, _ *types.ExtraSettings) error {
return nil
}
func (v *IntegratedValidatorImpl) SetPeerInvalidationListener(_ func(accountID string, peerIDs []string)) {
}
func (v *IntegratedValidatorImpl) Stop(_ context.Context) {
}
func (v *IntegratedValidatorImpl) ValidateFlowResponse(_ context.Context, _ string, flowResponse *proto.PKCEAuthorizationFlow) *proto.PKCEAuthorizationFlow {
return flowResponse
}

View File

@@ -17,6 +17,7 @@ import (
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
nbversion "github.com/netbirdio/netbird/version"
)
@@ -53,6 +54,7 @@ type DataSource interface {
GetAllAccounts(ctx context.Context) []*types.Account
GetStoreEngine() types.Engine
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
GetProxyMetrics(ctx context.Context) (store.ProxyMetrics, error)
}
// ConnManager peer connection manager that holds state for current active connections
@@ -223,6 +225,12 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
servicesAuthPassword int
servicesAuthPin int
servicesAuthOIDC int
// Private-service signals — track adoption of NetBird-only mode
// (services backed by an embedded proxy peer + access groups).
servicesPrivate int
servicesPrivateWithGroups int
servicesPrivateAccessGroupsSum int
servicesWithDirectUpstream int
)
start := time.Now()
metricsProperties := make(properties)
@@ -380,9 +388,31 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled {
servicesAuthOIDC++
}
if service.Private {
servicesPrivate++
if len(service.AccessGroups) > 0 {
servicesPrivateWithGroups++
}
servicesPrivateAccessGroupsSum += len(service.AccessGroups)
}
for _, target := range service.Targets {
if target.Options.DirectUpstream {
servicesWithDirectUpstream++
break
}
}
}
}
// Proxy / BYOP cluster signals come from the proxies table aggregated
// across all accounts in a single store query; nil on FileStore.
proxyMetrics, err := w.dataSource.GetProxyMetrics(ctx)
if err != nil {
log.WithContext(ctx).Debugf("collect proxy metrics: %v", err)
}
minActivePeerVersion, maxActivePeerVersion := getMinMaxVersion(peerActiveVersions)
metricsProperties["uptime"] = uptime
metricsProperties["accounts"] = accounts
@@ -430,6 +460,15 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
metricsProperties["services_auth_password"] = servicesAuthPassword
metricsProperties["services_auth_pin"] = servicesAuthPin
metricsProperties["services_auth_oidc"] = servicesAuthOIDC
metricsProperties["services_private"] = servicesPrivate
metricsProperties["services_private_with_access_groups"] = servicesPrivateWithGroups
metricsProperties["services_private_access_groups_sum"] = servicesPrivateAccessGroupsSum
metricsProperties["services_with_direct_upstream"] = servicesWithDirectUpstream
metricsProperties["proxy_clusters"] = proxyMetrics.Clusters
metricsProperties["proxy_clusters_byop"] = proxyMetrics.ClustersBYOP
metricsProperties["proxy_clusters_private"] = proxyMetrics.ClustersPrivate
metricsProperties["proxies"] = proxyMetrics.Proxies
metricsProperties["proxies_connected"] = proxyMetrics.ProxiesConnected
metricsProperties["custom_domains"] = customDomains
metricsProperties["custom_domains_validated"] = customDomainsValidated

Some files were not shown because too many files have changed in this diff Show More