Compare commits

..

32 Commits

Author SHA1 Message Date
crn4
eab0826b4e merge main 2026-05-29 14:08:09 +02:00
crn4
7048b87931 fix wasm bin filesize - type aliases 2026-05-29 13:50:06 +02:00
Zoltan Papp
174dc24867 [management] Add SSO session extend flow (management) (#6197)
* add SSO session extend flow (management)

Adds the management-server half of the SSO session-extension feature:

- New ExtendAuthSession gRPC RPC that refreshes a peer's session expiry
  using a fresh JWT, validated through the same pipeline as Login but
  without tearing down the tunnel or redoing the NetworkMap sync.
- Per-peer SessionExpiresAt timestamp on every LoginResponse and
  SyncResponse so connected clients learn the deadline on the existing
  long-lived stream, and admin-side changes (toggling expiration,
  changing the expiration window) reach every peer within seconds.
- SessionExpiresAt(...) helper on Peer that derives the absolute UTC
  deadline from LastLogin + the account-level PeerLoginExpiration
  setting, returning zero when the peer is not SSO-tracked or expiration
  is disabled.

The matching client-side consumer of these fields lands separately.

* encode SessionExpiresAt as 3-state on the wire

Previously the `sessionExpiresAt` field on LoginResponse, SyncResponse
and ExtendAuthSessionResponse was 2-state: a valid timestamp meant
"new deadline", and nil meant "clear". That conflated two distinct
meanings — "no info in this snapshot" vs "expiry is explicitly off /
peer is not SSO-tracked" — so a Sync push that legitimately couldn't
compute the deadline (settings lookup failed) would silently clear the
client's anchor and lose the warning window.

Three states now, encoded on the same field number (no .proto schema
churn — only comments and the server-side encoder change):

  - nil pointer (field absent) → "no info"; client preserves anchor
  - &Timestamp{} (seconds=0, nanos=0) → explicit "disabled / not SSO"
    sentinel; client clears
  - valid timestamp → new absolute UTC deadline

A new encodeSessionExpiresAt helper centralises the zero/non-zero
encoding and is shared by the Sync, Login and ExtendAuthSession
builders. The Sync builder still emits nil when settings are missing.
Login and ExtendAuthSession always carry an authoritative value.

The matching client-side decoder lands on feature/session-extend.

* add UserExtendedPeerSession activity event

ExtendAuthSession previously reused UserLoggedInPeer for its audit
record, which conflated two distinct user actions: a full interactive
SSO login (tunnel re-established, network map resync) versus an
in-place deadline refresh (tunnel untouched). Auditors reading the log
couldn't tell which one happened, and downstream dashboards/alerts on
"login" volume were polluted by routine extends.

Adds a dedicated UserExtendedPeerSession Activity (code 125,
"user.peer.session.extend") and switches ExtendPeerSession over to it.
The peer-extend audit trail is now distinguishable from interactive
logins.

* make ExtendAuthSession JWT-retry backoff cancellable

Skip the retry log and 200ms wait on the final attempt, and replace the
uncancellable time.Sleep with a select on time.After/ctx.Done so an
upstream cancellation aborts the wait instead of running it to
completion.
2026-05-28 19:14:14 +02:00
crn4
596952265d wgkey as peer id 2026-05-28 12:52:15 +02:00
Riccardo Manfrin
7ea5e37dd4 [client] Improve rosenpass support (#6136)
* Updates rosenpass version

go-rosenpass v0.4.0 → v0.5.42 bump — detailed findings

Change summary
cunicu.li/go-rosenpass  v0.4.0  → v0.5.42   (target)
cilium/ebpf             v0.15.0 → v0.19.0   (transitive)
gopacket/gopacket       v1.1.1  → v1.4.0    (transitive)
wireguard               2023-07 → 2023-12   (transitive)
wireguard/wgctrl        2023-04 → 2024-12   (transitive)

Wire interop

v0.4.0 (in v0.70.5) <-> v0.5.42 OK
v0.5.42 <-> v0.5.42 OK

Quantum resistance: true both ends

---
**Replay error eliminated.**

Before (on v0.4.0):

`ERROR Failed to handle message: failed to load biscuit (ICR1): detected replay`

Recurring every ~50ms for minutes at a time. Gone entirely after both ends upgraded to v0.5.42. Upstream fix in biscuit/replay handling between v0.4.x and v0.5.x series.

* Fixup [::]:port socket trying to send to v4

* Adds more tests on netbird<->rosenpass interactions

* Anticipates rp handler creation before generateConfig

* [client] Moves deterministic key gen into rosenpass

* go mod tidy

* Adds reminder to reason about rosenpass surface area

* Apply code rabbit suggestions
2026-05-28 09:01:18 +02:00
Riccardo Manfrin
9d7ef9b255 [client] Fix statemanager possible deadlock (#6228)
1. Stop() takes m.mu.Lock() and defers m.mu.Unlock()
2. <-m.done under lock
3. periodicStateSave defers close(m.done)
4. periodicStateSave calls PersistState() (line 256) which does m.mu.Lock()

Double Stop() remains idempotent: second cancel() on dead ctx
 (no-op) and reads done already closed (immediate return).
2026-05-28 08:54:15 +02:00
Pascal Fischer
944a258459 [management] extend nmap monitoring (#6271) 2026-05-27 16:56:02 +02:00
crn4
21cfec93d4 comments cleanup 2026-05-27 16:51:55 +02:00
Pascal Fischer
1f9a829f2c [management] update log levels (#6266) 2026-05-27 11:43:49 +02:00
crn4
98818e3095 merge main\ 2026-05-27 10:50:24 +02:00
Bethuel Mmbaga
14af179556 [management] Refactor management server bootstrap (#6256) 2026-05-26 17:44:28 +03:00
Pascal Fischer
1fbb5e6d5d [management] fix owner role update (#6264) 2026-05-26 16:37:58 +02:00
Viktor Liu
6771e35d57 [client] Release js.FuncOf callbacks in wasm ssh and rdp to prevent leaks (#5982) 2026-05-26 14:32:39 +02:00
Viktor Liu
e89b1e0596 [proxy, client] Bound embed client WireGuard per-Device memory (#5962) 2026-05-26 11:51:53 +02:00
crn4
5d5c2d9f95 filtering fix 2026-05-26 11:33:21 +02:00
Philip Laine
d542c60e21 Refactor Linux system info to use syscalls (#6230) 2026-05-25 21:00:24 +02:00
Viktor Liu
4983b5cf17 [client] Match DNS wildcard handlers on label boundaries (#6255) 2026-05-25 18:38:48 +02:00
Viktor Liu
b3b0feb3b8 [client] Filter scoped/cloned default routes from BSD network monitor RTM_ADD (#6208) 2026-05-25 18:38:21 +02:00
Maycon Santos
7aebdd69dd [management, client, proxy] add expose NetBird-only services over tunnel peers (#6226)
Adds a new "private" service mode for the reverse proxy: services reachable exclusively over the embedded WireGuard tunnel, gated by per-peer group membership instead of operator auth schemes.

Wire contract
- ProxyMapping.private (field 13): the proxy MUST call ValidateTunnelPeer and fail closed; operator schemes are bypassed.
- ProxyCapabilities.private (4) + supports_private_service (5): capability gate. Management never streams private mappings to proxies that don't claim the capability; the broadcast path applies the same filter via filterMappingsForProxy.
- ValidateTunnelPeer RPC: resolves an inbound tunnel IP to a peer, checks the peer's groups against service.AccessGroups, and mints a session JWT on success. checkPeerGroupAccess fails closed when a private service has empty AccessGroups.
- ValidateSession/ValidateTunnelPeer responses now carry peer_group_ids + peer_group_names so the proxy can authorise policy-aware middlewares without an extra management round-trip.
- ProxyInboundListener + SendStatusUpdate.inbound_listener: per-account inbound listener state surfaced to dashboards.
- PathTargetOptions.direct_upstream (11): bypass the embedded NetBird client and dial the target via the proxy host's network stack for upstreams reachable without WireGuard.

Data model
- Service.Private (bool) + Service.AccessGroups ([]string, JSON- serialised). Validate() rejects bearer auth on private services. Copy() deep-copies AccessGroups. pgx getServices loads the columns.
- DomainConfig.Private threaded into the proxy auth middleware. Request handler routes private services through forwardWithTunnelPeer and returns 403 on validation failure.
- Account-level SynthesizePrivateServiceZones (synthetic DNS) and injectPrivateServicePolicies (synthetic ACL) gate on len(svc.AccessGroups) > 0.

Proxy
- /netbird proxy --private (embedded mode) flag; Config.Private in proxy/lifecycle.go.
- Per-account inbound listener (proxy/inbound.go) binding HTTP/HTTPS on the embedded NetBird client's WireGuard tunnel netstack.
- proxy/internal/auth/tunnel_cache: ValidateTunnelPeer response cache with single-flight de-duplication and per-account eviction.
- Local peerstore short-circuit: when the inbound IP isn't in the account roster, deny fast without an RPC.
- proxy/server.go reports SupportsPrivateService=true and redacts the full ProxyMapping JSON from info logs (auth_token + header-auth hashed values now only at debug level).

Identity forwarding
- ValidateSessionJWT returns user_id, email, method, groups, group_names. sessionkey.Claims carries Email + Groups + GroupNames so the proxy can stamp identity onto upstream requests without an extra management round-trip on every cookie-bearing request.
- CapturedData carries userEmail / userGroups / userGroupNames; the proxy stamps X-NetBird-User and X-NetBird-Groups on r.Out from the authenticated identity (strips client-supplied values first to prevent spoofing).
- AccessLog.UserGroups: access-log enrichment captures the user's group memberships at write time so the dashboard can render group context without reverse-resolving stale memberships.

OpenAPI/dashboard surface
- ReverseProxyService gains private + access_groups; ReverseProxyCluster gains private + supports_private. ReverseProxyTarget target_type enum gains "cluster". ServiceTargetOptions gains direct_upstream. ProxyAccessLog gains user_groups.
2026-05-25 17:41:50 +02:00
crn4
13e41e432c idp dex fix 2026-05-21 15:21:28 +02:00
crn4
efa6a3f502 missed file 2026-05-20 12:41:05 +02:00
crn4
5fbcdeceac more comments 2026-05-19 21:41:08 +02:00
crn4
3a1bbeba90 review comments 2026-05-19 20:27:50 +02:00
crn4
728057ef15 missed files for client side and shared files 2026-05-19 14:46:23 +02:00
crn4
582cd70086 client side and components on shared folder 2026-05-19 14:46:09 +02:00
crn4
9bbbafaf69 int id for networks and posture checks migration 2026-05-19 14:45:40 +02:00
crn4
672b057aa0 fix Group.Copy losing AccountSeqID 2026-05-19 14:43:59 +02:00
crn4
b9a0186200 fix routes filtering in account componnents 2026-05-19 14:43:49 +02:00
crn4
9083bdb977 capabilities conditioning 2026-05-19 14:43:38 +02:00
crn4
b194af48b8 wire size benches fix 2026-05-19 14:43:28 +02:00
crn4
4543780ef0 grpc components encoding with optimisations 2026-05-19 14:43:17 +02:00
crn4
2de0283971 init int inds migration 2026-05-19 14:42:55 +02:00
204 changed files with 19996 additions and 3361 deletions

View File

@@ -35,7 +35,7 @@ jobs:
display_name: Linux display_name: Linux
name: ${{ matrix.display_name }} name: ${{ matrix.display_name }}
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
timeout-minutes: 15 timeout-minutes: 25
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
@@ -58,4 +58,4 @@ jobs:
skip-cache: true skip-cache: true
skip-save-cache: true skip-save-cache: true
cache-invalidation-interval: 0 cache-invalidation-interval: 0
args: --timeout=12m args: --timeout=20m

View File

@@ -20,34 +20,66 @@ jobs:
per_page: 100, per_page: 100,
}); });
const pbFiles = files.filter(f => f.filename.endsWith('.pb.go')); const modifiedPbFiles = files.filter(
const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename); f => f.filename.endsWith('.pb.go') && f.status === 'modified'
if (missingPatch.length > 0) { );
core.setFailed( if (modifiedPbFiles.length === 0) {
`Cannot inspect patch data for:\n` + console.log('No modified .pb.go files to check');
missingPatch.map(f => `- ${f}`).join('\n') +
`\nThis can happen with very large PRs. Verify proto versions manually.`
);
return; return;
} }
const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
const violations = [];
for (const file of pbFiles) { const versionPattern = /^\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
const changed = file.patch const baseSha = context.payload.pull_request.base.sha;
.split('\n') const headSha = context.payload.pull_request.head.sha;
.filter(line => versionPattern.test(line));
if (changed.length > 0) { async function getVersionHeader(path, ref) {
try {
const res = await github.rest.repos.getContent({
owner: context.repo.owner,
repo: context.repo.repo,
path,
ref,
});
if (!res.data.content) {
return { ok: false, reason: 'no inline content (file too large)' };
}
const content = Buffer.from(res.data.content, 'base64').toString('utf8');
const lines = content
.split('\n')
.slice(0, 20)
.filter(line => versionPattern.test(line));
return { ok: true, lines };
} catch (e) {
return { ok: false, reason: e.message };
}
}
const violations = [];
for (const file of modifiedPbFiles) {
const [base, head] = await Promise.all([
getVersionHeader(file.filename, baseSha),
getVersionHeader(file.filename, headSha),
]);
if (!base.ok || !head.ok) {
core.warning(
`Skipping ${file.filename}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}`
);
continue;
}
if (base.lines.join('\n') !== head.lines.join('\n')) {
violations.push({ violations.push({
file: file.filename, file: file.filename,
lines: changed, base: base.lines,
head: head.lines,
}); });
} }
} }
if (violations.length > 0) { if (violations.length > 0) {
const details = violations.map(v => const details = violations.map(v =>
`${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}` `${v.file}:\n` +
` base:\n${v.base.map(l => ' ' + l).join('\n') || ' (none)'}\n` +
` head:\n${v.head.map(l => ' ' + l).join('\n') || ' (none)'}`
).join('\n\n'); ).join('\n\n');
core.setFailed( core.setFailed(

View File

@@ -11,7 +11,7 @@ import (
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
nbcache "github.com/netbirdio/netbird/management/server/cache" nbcache "github.com/netbirdio/netbird/management/server/cache"
@@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
t.Fatal(err) t.Fatal(err)
} }
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore) iv, _ := validator.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(ctx) metrics, err := telemetry.NewDefaultAppMetrics(ctx)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -12,6 +12,7 @@ import (
"sync" "sync"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
wgdevice "golang.zx2c4.com/wireguard/device"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack" wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
@@ -84,6 +85,12 @@ type Options struct {
DisableIPv6 bool DisableIPv6 bool
// BlockInbound blocks all inbound connections from peers // BlockInbound blocks all inbound connections from peers
BlockInbound bool BlockInbound bool
// BlockLANAccess blocks the embedded peer from reaching the host's
// LAN (RFC 1918, link-local, loopback) when it's used as a routing
// peer. Mirrors profilemanager.ConfigInput.BlockLANAccess. Useful
// when the embedded client must never act as a stepping stone into
// the host's local network (e.g. the proxy's overlay peer).
BlockLANAccess bool
// WireguardPort is the port for the tunnel interface. Use 0 for a random port. // WireguardPort is the port for the tunnel interface. Use 0 for a random port.
WireguardPort *int WireguardPort *int
// MTU is the MTU for the tunnel interface. // MTU is the MTU for the tunnel interface.
@@ -94,6 +101,26 @@ type Options struct {
MTU *uint16 MTU *uint16
// DNSLabels defines additional DNS labels configured in the peer. // DNSLabels defines additional DNS labels configured in the peer.
DNSLabels []string DNSLabels []string
// Performance configures the tunnel's buffer pool cap and batch size.
Performance Performance
}
// Performance configures the embedded client's tunnel memory/throughput knobs.
//
// These settings are process-global: any non-nil field also becomes the
// default for Clients constructed by later embed.New calls in the same
// process. Nil fields are ignored.
type Performance struct {
// PreallocatedBuffersPerPool caps the per-tunnel buffer pool. Zero
// leaves the pool unbounded. Lower values trade throughput for a
// tighter memory ceiling. May also be changed on a running Client via
// Client.SetPerformance, provided this field was nonzero at construction.
PreallocatedBuffersPerPool *uint32
// MaxBatchSize overrides the number of packets the tunnel reads or
// writes per syscall, which also bounds eager buffer allocation per
// worker. Zero uses the platform default. Applied at construction
// only; ignored by Client.SetPerformance.
MaxBatchSize *uint32
} }
// validateCredentials checks that exactly one credential type is provided // validateCredentials checks that exactly one credential type is provided
@@ -175,6 +202,7 @@ func New(opts Options) (*Client, error) {
DisableClientRoutes: &opts.DisableClientRoutes, DisableClientRoutes: &opts.DisableClientRoutes,
DisableIPv6: &opts.DisableIPv6, DisableIPv6: &opts.DisableIPv6,
BlockInbound: &opts.BlockInbound, BlockInbound: &opts.BlockInbound,
BlockLANAccess: &opts.BlockLANAccess,
WireguardPort: opts.WireguardPort, WireguardPort: opts.WireguardPort,
MTU: opts.MTU, MTU: opts.MTU,
DNSLabels: parsedLabels, DNSLabels: parsedLabels,
@@ -192,6 +220,13 @@ func New(opts Options) (*Client, error) {
config.PrivateKey = opts.PrivateKey config.PrivateKey = opts.PrivateKey
} }
if opts.Performance.PreallocatedBuffersPerPool != nil {
wgdevice.SetPreallocatedBuffersPerPool(*opts.Performance.PreallocatedBuffersPerPool)
}
if opts.Performance.MaxBatchSize != nil {
wgdevice.SetMaxBatchSizeOverride(*opts.Performance.MaxBatchSize)
}
return &Client{ return &Client{
deviceName: opts.DeviceName, deviceName: opts.DeviceName,
setupKey: opts.SetupKey, setupKey: opts.SetupKey,
@@ -405,6 +440,21 @@ func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession,
}, nil }, nil
} }
// IdentityForIP looks up a remote peer by its tunnel IP using the
// embedded client's status recorder. Returns the peer's WireGuard public
// key and FQDN. ok=false means the IP isn't in this client's peer
// roster — callers should treat that as "unknown peer".
func (c *Client) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
if !ip.IsValid() || c.recorder == nil {
return "", "", false
}
state, found := c.recorder.PeerStateByIP(ip.String())
if !found {
return "", "", false
}
return state.PubKey, state.FQDN, true
}
// Status returns the current status of the client. // Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) { func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock() c.mu.Lock()
@@ -473,6 +523,25 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
return sshcommon.VerifyHostKey(storedKey, key, peerAddress) return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
} }
// SetPerformance retunes a running Client. Only PreallocatedBuffersPerPool
// takes effect, and only when it was nonzero at construction;
// MaxBatchSize is construction-only and returns an error if set here.
//
// Returns ErrClientNotStarted / ErrEngineNotStarted if the Client is not
// running yet.
func (c *Client) SetPerformance(t Performance) error {
if t.MaxBatchSize != nil {
return errors.New("MaxBatchSize is construction-only and cannot be changed at runtime")
}
engine, err := c.getEngine()
if err != nil {
return err
}
return engine.SetPerformance(internal.Performance{
PreallocatedBuffersPerPool: t.PreallocatedBuffersPerPool,
})
}
// StartCapture begins capturing packets on this client's tunnel device. // StartCapture begins capturing packets on this client's tunnel device.
// Only one capture can be active at a time; starting a new one stops the previous. // Only one capture can be active at a time; starting a new one stops the previous.
// Call StopCapture (or CaptureSession.Stop) to end it. // Call StopCapture (or CaptureSession.Stop) to end it.

View File

@@ -1,199 +0,0 @@
package iptables
import (
"net/netip"
"testing"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func iptRefcountIfaceV4() *iFaceMock {
return &iFaceMock{
NameFunc: func() string { return "wt-refcount" },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("10.20.0.1"),
Network: netip.MustParsePrefix("10.20.0.0/24"),
}
},
}
}
func iptRefcountIfaceDual() *iFaceMock {
return &iFaceMock{
NameFunc: func() string { return "wt-refcount" },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("10.20.0.1"),
Network: netip.MustParsePrefix("10.20.0.0/24"),
IPv6: netip.MustParseAddr("fd00::1"),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
}
},
}
}
func newIptRefcountManager(t *testing.T, dual bool) *Manager {
t.Helper()
var ifMock *iFaceMock
if dual {
ifMock = iptRefcountIfaceDual()
} else {
ifMock = iptRefcountIfaceV4()
}
m, err := Create(ifMock, iface.DefaultMTU)
require.NoError(t, err, "create manager")
require.NoError(t, m.Init(nil), "init manager")
t.Cleanup(func() {
require.NoError(t, m.Close(nil), "close manager")
})
return m
}
func iptDnatV4(port uint16) fw.ForwardRule {
return fw.ForwardRule{
Protocol: fw.ProtocolTCP,
DestinationPort: fw.Port{Values: []uint16{port}},
TranslatedAddress: netip.MustParseAddr("10.20.0.2"),
TranslatedPort: fw.Port{Values: []uint16{80}},
}
}
func iptDnatV6(port uint16) fw.ForwardRule {
return fw.ForwardRule{
Protocol: fw.ProtocolTCP,
DestinationPort: fw.Port{Values: []uint16{port}},
TranslatedAddress: netip.MustParseAddr("fd00::2"),
TranslatedPort: fw.Port{Values: []uint16{80}},
}
}
// TestIptablesDNAT_RefcountBalancedV4 covers a Balanced Add/Delete pair on v4.
func TestIptablesDNAT_RefcountBalancedV4(t *testing.T) {
m := newIptRefcountManager(t, false)
state := m.router.ipFwdState
r1, err := m.AddDNATRule(iptDnatV4(7081))
require.NoError(t, err, "add v4 dnat 1")
v4, v6 := state.Counts()
require.Equal(t, 1, v4, "v4 refcount after first add")
require.Equal(t, 0, v6, "v6 refcount unchanged")
r2, err := m.AddDNATRule(iptDnatV4(7082))
require.NoError(t, err, "add v4 dnat 2")
v4, v6 = state.Counts()
require.Equal(t, 2, v4, "v4 refcount after second add")
require.Equal(t, 0, v6, "v6 refcount unchanged")
require.NoError(t, m.DeleteDNATRule(r1))
v4, v6 = state.Counts()
require.Equal(t, 1, v4, "v4 refcount after first delete")
require.Equal(t, 0, v6, "v6 refcount unchanged")
require.NoError(t, m.DeleteDNATRule(r2))
v4, v6 = state.Counts()
require.Equal(t, 0, v4, "v4 refcount after second delete")
require.Equal(t, 0, v6, "v6 refcount unchanged")
}
// TestIptablesDNAT_RefcountBalancedV6 checks the v6 path increments v6 only and
// decrements back to zero.
func TestIptablesDNAT_RefcountBalancedV6(t *testing.T) {
m := newIptRefcountManager(t, true)
require.NotNil(t, m.router6, "v6 router")
require.Same(t, m.router.ipFwdState, m.router6.ipFwdState, "shared state")
state := m.router.ipFwdState
r1, err := m.AddDNATRule(iptDnatV6(9081))
require.NoError(t, err, "add v6 dnat 1")
v4, v6 := state.Counts()
require.Equal(t, 0, v4)
require.Equal(t, 1, v6, "v6 refcount after first add")
r2, err := m.AddDNATRule(iptDnatV6(9082))
require.NoError(t, err, "add v6 dnat 2")
v4, v6 = state.Counts()
require.Equal(t, 0, v4, "v4 refcount unchanged")
require.Equal(t, 2, v6, "v6 refcount after second add")
require.NoError(t, m.DeleteDNATRule(r1))
v4, v6 = state.Counts()
require.Equal(t, 0, v4, "v4 refcount unchanged")
require.Equal(t, 1, v6, "v6 refcount after first delete")
require.NoError(t, m.DeleteDNATRule(r2))
v4, v6 = state.Counts()
require.Equal(t, 0, v4)
require.Equal(t, 0, v6, "v6 refcount after second delete")
}
// TestIptablesDNAT_DuplicateAddNoLeak verifies the duplicate-rule path returns
// without bumping the refcount.
func TestIptablesDNAT_DuplicateAddNoLeak(t *testing.T) {
m := newIptRefcountManager(t, true)
state := m.router.ipFwdState
rule := iptDnatV4(7083)
r1, err := m.AddDNATRule(rule)
require.NoError(t, err)
v4, _ := state.Counts()
require.Equal(t, 1, v4)
_, err = m.AddDNATRule(rule)
require.NoError(t, err, "duplicate add")
v4, _ = state.Counts()
require.Equal(t, 1, v4, "duplicate add must not increment")
require.NoError(t, m.DeleteDNATRule(r1))
v4, _ = state.Counts()
require.Equal(t, 0, v4, "single delete must drop to zero")
}
// TestIptablesDNAT_DeleteMissingNoUnderflow verifies Delete on an unknown rule
// neither errors nor releases the refcount.
func TestIptablesDNAT_DeleteMissingNoUnderflow(t *testing.T) {
m := newIptRefcountManager(t, true)
state := m.router.ipFwdState
phantom := iptDnatV4(7099)
require.NoError(t, m.DeleteDNATRule(&phantom), "delete missing v4")
v4, v6 := state.Counts()
require.Equal(t, 0, v4)
require.Equal(t, 0, v6)
phantom6 := iptDnatV6(9099)
require.NoError(t, m.DeleteDNATRule(&phantom6), "delete missing v6")
v4, v6 = state.Counts()
require.Equal(t, 0, v4)
require.Equal(t, 0, v6)
r1, err := m.AddDNATRule(iptDnatV4(7100))
require.NoError(t, err)
v4, _ = state.Counts()
require.Equal(t, 1, v4, "real add still increments after phantom delete")
require.NoError(t, m.DeleteDNATRule(r1))
}
// TestIptablesDNAT_DoubleDeleteNoUnderflow verifies a second Delete on the same
// rule is a no-op.
func TestIptablesDNAT_DoubleDeleteNoUnderflow(t *testing.T) {
m := newIptRefcountManager(t, true)
state := m.router.ipFwdState
r1, err := m.AddDNATRule(iptDnatV6(9083))
require.NoError(t, err)
_, v6 := state.Counts()
require.Equal(t, 1, v6)
require.NoError(t, m.DeleteDNATRule(r1), "first delete")
_, v6 = state.Counts()
require.Equal(t, 0, v6)
require.NoError(t, m.DeleteDNATRule(r1), "second delete must be no-op")
_, v6 = state.Counts()
require.Equal(t, 0, v6, "double delete must not underflow")
}

View File

@@ -89,7 +89,7 @@ func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error {
} }
// Share the same IP forwarding state with the v4 router, since // Share the same IP forwarding state with the v4 router, since
// Forwarding refcounter is per-family but shared between v4 and v6 routers. // EnableIPForwarding controls both v4 and v6 sysctls.
m.router6.ipFwdState = m.router.ipFwdState m.router6.ipFwdState = m.router.ipFwdState
m.aclMgr6, err = newAclManager(ip6Client, wgIface) m.aclMgr6, err = newAclManager(ip6Client, wgIface)
@@ -402,33 +402,17 @@ func (m *Manager) SetLogLevel(log.Level) {
} }
func (m *Manager) EnableRouting() error { func (m *Manager) EnableRouting() error {
if err := m.router.ipFwdState.RequestForwarding(false); err != nil { if err := m.router.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IPv4 forwarding: %w", err) return fmt.Errorf("enable IP forwarding: %w", err)
}
// v6 only when the overlay actually has v6.
if m.router6 == nil {
return nil
}
if err := m.router.ipFwdState.RequestForwarding(true); err != nil {
if rerr := m.router.ipFwdState.ReleaseForwarding(false); rerr != nil {
log.Warnf("rollback v4 forwarding: %v", rerr)
}
return fmt.Errorf("enable IPv6 forwarding: %w", err)
} }
return nil return nil
} }
func (m *Manager) DisableRouting() error { func (m *Manager) DisableRouting() error {
var merr *multierror.Error if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
if err := m.router.ipFwdState.ReleaseForwarding(false); err != nil { return fmt.Errorf("disable IP forwarding: %w", err)
merr = multierror.Append(merr, fmt.Errorf("disable IPv4 forwarding: %w", err))
} }
if m.router6 != nil { return nil
if err := m.router.ipFwdState.ReleaseForwarding(true); err != nil {
merr = multierror.Append(merr, fmt.Errorf("disable IPv6 forwarding: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
} }
// AddDNATRule adds a DNAT rule // AddDNATRule adds a DNAT rule

View File

@@ -101,7 +101,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
wgIface: wgIface, wgIface: wgIface,
mtu: mtu, mtu: mtu,
v6: iptablesClient.Proto() == iptables.ProtocolIPv6, v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
ipFwdState: ipfwdstate.NewIPForwardingState(wgIface.Name()), ipFwdState: ipfwdstate.NewIPForwardingState(),
} }
r.ipsetCounter = refcounter.New( r.ipsetCounter = refcounter.New(
@@ -763,6 +763,10 @@ func (r *router) updateState() {
} }
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
ruleKey := rule.ID() ruleKey := rule.ID()
if _, exists := r.rules[ruleKey+dnatSuffix]; exists { if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
return rule, nil return rule, nil
@@ -837,16 +841,6 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
r.rules[key] = ruleInfo.rule r.rules[key] = ruleInfo.rule
} }
if err := r.ipFwdState.RequestForwarding(r.v6); err != nil {
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
log.Errorf("rollback failed: %v", rollbackErr)
}
for key := range rules {
delete(r.rules, key)
}
return nil, fmt.Errorf("enable forwarding: %w", err)
}
r.updateState() r.updateState()
return rule, nil return rule, nil
} }
@@ -867,15 +861,12 @@ func (r *router) rollbackRules(rules map[string]ruleInfo) error {
} }
func (r *router) DeleteDNATRule(rule firewall.Rule) error { func (r *router) DeleteDNATRule(rule firewall.Rule) error {
ruleKey := rule.ID() if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
_, hadDNAT := r.rules[ruleKey+dnatSuffix]
_, hadSNAT := r.rules[ruleKey+snatSuffix]
_, hadFWD := r.rules[ruleKey+fwdSuffix]
if !hadDNAT && !hadSNAT && !hadFWD {
return nil
} }
ruleKey := rule.ID()
var merr *multierror.Error var merr *multierror.Error
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists { if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil { if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
@@ -898,10 +889,6 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
delete(r.rules, ruleKey+fwdSuffix) delete(r.rules, ruleKey+fwdSuffix)
} }
if err := r.ipFwdState.ReleaseForwarding(r.v6); err != nil {
log.Errorf("%v", err)
}
r.updateState() r.updateState()
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }

View File

@@ -1,208 +0,0 @@
package nftables
import (
"net/netip"
"testing"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func nftRefcountIfaceV4() *iFaceMock {
return &iFaceMock{
NameFunc: func() string { return "wt-refcount" },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.96.0.1"),
Network: netip.MustParsePrefix("100.96.0.0/16"),
}
},
}
}
func nftRefcountIfaceDual() *iFaceMock {
return &iFaceMock{
NameFunc: func() string { return "wt-refcount" },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.96.0.1"),
Network: netip.MustParsePrefix("100.96.0.0/16"),
IPv6: netip.MustParseAddr("fd00::1"),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
}
},
}
}
func newNftRefcountManager(t *testing.T, dual bool) *Manager {
t.Helper()
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
var ifMock *iFaceMock
if dual {
ifMock = nftRefcountIfaceDual()
} else {
ifMock = nftRefcountIfaceV4()
}
m, err := Create(ifMock, iface.DefaultMTU)
require.NoError(t, err, "create manager")
require.NoError(t, m.Init(nil), "init manager")
t.Cleanup(func() {
require.NoError(t, m.Close(nil), "close manager")
})
return m
}
func dnatV4(port uint16) fw.ForwardRule {
return fw.ForwardRule{
Protocol: fw.ProtocolTCP,
DestinationPort: fw.Port{Values: []uint16{port}},
TranslatedAddress: netip.MustParseAddr("100.96.0.2"),
TranslatedPort: fw.Port{Values: []uint16{80}},
}
}
func dnatV6(port uint16) fw.ForwardRule {
return fw.ForwardRule{
Protocol: fw.ProtocolTCP,
DestinationPort: fw.Port{Values: []uint16{port}},
TranslatedAddress: netip.MustParseAddr("fd00::2"),
TranslatedPort: fw.Port{Values: []uint16{80}},
}
}
// TestNftablesDNAT_RefcountBalancedV4 verifies that Add/Delete pairs leave the
// v4 refcount at zero.
func TestNftablesDNAT_RefcountBalancedV4(t *testing.T) {
m := newNftRefcountManager(t, false)
state := m.router.ipFwdState
r1, err := m.AddDNATRule(dnatV4(8081))
require.NoError(t, err, "add v4 dnat 1")
v4, v6 := state.Counts()
require.Equal(t, 1, v4, "v4 refcount after first add")
require.Equal(t, 0, v6, "v6 refcount unchanged")
r2, err := m.AddDNATRule(dnatV4(8082))
require.NoError(t, err, "add v4 dnat 2")
v4, v6 = state.Counts()
require.Equal(t, 2, v4, "v4 refcount after second add")
require.Equal(t, 0, v6, "v6 refcount unchanged")
require.NoError(t, m.DeleteDNATRule(r1), "delete v4 dnat 1")
v4, v6 = state.Counts()
require.Equal(t, 1, v4, "v4 refcount after first delete")
require.Equal(t, 0, v6, "v6 refcount unchanged")
require.NoError(t, m.DeleteDNATRule(r2), "delete v4 dnat 2")
v4, v6 = state.Counts()
require.Equal(t, 0, v4, "v4 refcount after second delete")
require.Equal(t, 0, v6, "v6 refcount unchanged")
}
// TestNftablesDNAT_RefcountBalancedV6 verifies the v6 path increments v6 only
// and decrements back to zero on Delete.
func TestNftablesDNAT_RefcountBalancedV6(t *testing.T) {
m := newNftRefcountManager(t, true)
require.NotNil(t, m.router6, "v6 router")
require.Same(t, m.router.ipFwdState, m.router6.ipFwdState, "shared state")
state := m.router.ipFwdState
r1, err := m.AddDNATRule(dnatV6(9091))
require.NoError(t, err, "add v6 dnat 1")
v4, v6 := state.Counts()
require.Equal(t, 0, v4, "v4 refcount unchanged")
require.Equal(t, 1, v6, "v6 refcount after first add")
r2, err := m.AddDNATRule(dnatV6(9092))
require.NoError(t, err, "add v6 dnat 2")
v4, v6 = state.Counts()
require.Equal(t, 0, v4)
require.Equal(t, 2, v6, "v6 refcount after second add")
require.NoError(t, m.DeleteDNATRule(r1), "delete v6 dnat 1")
v4, v6 = state.Counts()
require.Equal(t, 0, v4, "v4 refcount unchanged")
require.Equal(t, 1, v6, "v6 refcount after first delete")
require.NoError(t, m.DeleteDNATRule(r2), "delete v6 dnat 2")
v4, v6 = state.Counts()
require.Equal(t, 0, v4)
require.Equal(t, 0, v6, "v6 refcount after second delete")
}
// TestNftablesDNAT_DuplicateAddNoLeak verifies that a duplicate Add (same
// ForwardRule) does not double-increment the refcount.
func TestNftablesDNAT_DuplicateAddNoLeak(t *testing.T) {
m := newNftRefcountManager(t, true)
state := m.router.ipFwdState
rule := dnatV4(8083)
r1, err := m.AddDNATRule(rule)
require.NoError(t, err, "add v4 dnat")
v4, _ := state.Counts()
require.Equal(t, 1, v4)
// duplicate add: same rule ID, must be a no-op for the refcount.
_, err = m.AddDNATRule(rule)
require.NoError(t, err, "duplicate add")
v4, _ = state.Counts()
require.Equal(t, 1, v4, "duplicate add must not increment")
require.NoError(t, m.DeleteDNATRule(r1), "delete v4 dnat")
v4, _ = state.Counts()
require.Equal(t, 0, v4, "single delete must drop to zero")
}
// TestNftablesDNAT_DeleteMissingNoUnderflow verifies deleting a rule that was
// never added does not underflow the refcount.
func TestNftablesDNAT_DeleteMissingNoUnderflow(t *testing.T) {
m := newNftRefcountManager(t, true)
state := m.router.ipFwdState
// Construct a Rule reference for something never added. The router stores
// rules by ID(), and DeleteDNATRule looks them up in r.rules; a missing
// entry must be a no-op rather than calling Release.
phantom := dnatV4(8099)
require.NoError(t, m.DeleteDNATRule(&phantom), "delete missing v4 dnat")
v4, v6 := state.Counts()
require.Equal(t, 0, v4, "v4 refcount unaffected by missing delete")
require.Equal(t, 0, v6, "v6 refcount unaffected")
phantom6 := dnatV6(9099)
require.NoError(t, m.DeleteDNATRule(&phantom6), "delete missing v6 dnat")
v4, v6 = state.Counts()
require.Equal(t, 0, v4)
require.Equal(t, 0, v6, "v6 refcount unaffected by missing delete")
// And after a phantom delete, a real add still results in count=1.
r1, err := m.AddDNATRule(dnatV4(8100))
require.NoError(t, err, "add v4 dnat after phantom delete")
v4, _ = state.Counts()
require.Equal(t, 1, v4, "real add still increments after phantom delete")
require.NoError(t, m.DeleteDNATRule(r1))
}
// TestNftablesDNAT_DoubleDeleteNoUnderflow verifies that deleting the same rule
// twice does not underflow the refcount (the second delete is a no-op).
func TestNftablesDNAT_DoubleDeleteNoUnderflow(t *testing.T) {
m := newNftRefcountManager(t, true)
state := m.router.ipFwdState
r1, err := m.AddDNATRule(dnatV6(9093))
require.NoError(t, err)
_, v6 := state.Counts()
require.Equal(t, 1, v6)
require.NoError(t, m.DeleteDNATRule(r1), "first delete")
_, v6 = state.Counts()
require.Equal(t, 0, v6)
require.NoError(t, m.DeleteDNATRule(r1), "second delete must be no-op")
_, v6 = state.Counts()
require.Equal(t, 0, v6, "double delete must not underflow")
}

View File

@@ -105,8 +105,8 @@ func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mt
return fmt.Errorf("create v6 router: %w", err) return fmt.Errorf("create v6 router: %w", err)
} }
// Share the per-family forwarding refcounter with the v4 router so a v4 // Share the same IP forwarding state with the v4 router, since
// rule and a v6 rule against the same state machine cooperate cleanly. // EnableIPForwarding controls both v4 and v6 sysctls.
m.router6.ipFwdState = m.router.ipFwdState m.router6.ipFwdState = m.router.ipFwdState
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw) m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
@@ -530,33 +530,17 @@ func (m *Manager) SetLogLevel(log.Level) {
} }
func (m *Manager) EnableRouting() error { func (m *Manager) EnableRouting() error {
if err := m.router.ipFwdState.RequestForwarding(false); err != nil { if err := m.router.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IPv4 forwarding: %w", err) return fmt.Errorf("enable IP forwarding: %w", err)
}
// v6 only when the overlay actually has v6.
if m.router6 == nil {
return nil
}
if err := m.router.ipFwdState.RequestForwarding(true); err != nil {
if rerr := m.router.ipFwdState.ReleaseForwarding(false); rerr != nil {
log.Warnf("rollback v4 forwarding: %v", rerr)
}
return fmt.Errorf("enable IPv6 forwarding: %w", err)
} }
return nil return nil
} }
func (m *Manager) DisableRouting() error { func (m *Manager) DisableRouting() error {
var merr *multierror.Error if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
if err := m.router.ipFwdState.ReleaseForwarding(false); err != nil { return fmt.Errorf("disable IP forwarding: %w", err)
merr = multierror.Append(merr, fmt.Errorf("disable IPv4 forwarding: %w", err))
} }
if m.router6 != nil { return nil
if err := m.router.ipFwdState.ReleaseForwarding(true); err != nil {
merr = multierror.Append(merr, fmt.Errorf("disable IPv6 forwarding: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
} }
// Flush rule/chain/set operations from the buffer // Flush rule/chain/set operations from the buffer

View File

@@ -93,7 +93,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
rules: make(map[string]*nftables.Rule), rules: make(map[string]*nftables.Rule),
af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4), af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4),
wgIface: wgIface, wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(wgIface.Name()), ipFwdState: ipfwdstate.NewIPForwardingState(),
mtu: mtu, mtu: mtu,
} }
@@ -1550,6 +1550,10 @@ func (r *router) refreshRulesMap() error {
} }
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
ruleKey := rule.ID() ruleKey := rule.ID()
if _, exists := r.rules[ruleKey+dnatSuffix]; exists { if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
return rule, nil return rule, nil
@@ -1560,18 +1564,7 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
return nil, fmt.Errorf("convert protocol to number: %w", err) return nil, fmt.Errorf("convert protocol to number: %w", err)
} }
// Request forwarding before queueing rules: addDnatRedirect/addDnatMasq
// buffer netlink messages on r.conn that the next caller's Flush would
// commit if we returned without flushing them ourselves.
v6 := r.af.tableFamily == nftables.TableFamilyIPv6
if err := r.ipFwdState.RequestForwarding(v6); err != nil {
return nil, fmt.Errorf("enable forwarding: %w", err)
}
if err := r.addDnatRedirect(rule, protoNum, ruleKey); err != nil { if err := r.addDnatRedirect(rule, protoNum, ruleKey); err != nil {
if rerr := r.ipFwdState.ReleaseForwarding(v6); rerr != nil {
log.Warnf("rollback forwarding refcount: %v", rerr)
}
return nil, err return nil, err
} }
@@ -1583,11 +1576,6 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
// TODO: find chains with drop policies and add rules there // TODO: find chains with drop policies and add rules there
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
if rerr := r.ipFwdState.ReleaseForwarding(v6); rerr != nil {
log.Warnf("rollback forwarding refcount: %v", rerr)
}
delete(r.rules, ruleKey+dnatSuffix)
delete(r.rules, ruleKey+snatSuffix)
return nil, fmt.Errorf("flush rules: %w", err) return nil, fmt.Errorf("flush rules: %w", err)
} }
@@ -1790,18 +1778,16 @@ func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey
} }
func (r *router) DeleteDNATRule(rule firewall.Rule) error { func (r *router) DeleteDNATRule(rule firewall.Rule) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
ruleKey := rule.ID() ruleKey := rule.ID()
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
_, hadDNAT := r.rules[ruleKey+dnatSuffix]
_, hadSNAT := r.rules[ruleKey+snatSuffix]
if !hadDNAT && !hadSNAT {
return nil
}
var merr *multierror.Error var merr *multierror.Error
var needsFlush bool var needsFlush bool
@@ -1838,10 +1824,6 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
delete(r.rules, ruleKey+snatSuffix) delete(r.rules, ruleKey+snatSuffix)
} }
if err := r.ipFwdState.ReleaseForwarding(r.af.tableFamily == nftables.TableFamilyIPv6); err != nil {
log.Errorf("%v", err)
}
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }

View File

@@ -844,10 +844,6 @@ func collectSysctls() string {
[]string{"net.ipv4.conf.all.src_valid_mark", "net.ipv4.conf.default.src_valid_mark"}, []string{"net.ipv4.conf.all.src_valid_mark", "net.ipv4.conf.default.src_valid_mark"},
listInterfaceSysctls("ipv4", "src_valid_mark")..., listInterfaceSysctls("ipv4", "src_valid_mark")...,
)) ))
writeSysctlGroup(&builder, "accept_ra", append(
[]string{"net.ipv6.conf.all.accept_ra", "net.ipv6.conf.default.accept_ra"},
listInterfaceSysctls("ipv6", "accept_ra")...,
))
writeSysctlGroup(&builder, "conntrack", []string{ writeSysctlGroup(&builder, "conntrack", []string{
"net.netfilter.nf_conntrack_acct", "net.netfilter.nf_conntrack_acct",
"net.netfilter.nf_conntrack_tcp_loose", "net.netfilter.nf_conntrack_tcp_loose",

View File

@@ -339,8 +339,7 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
case entry.Pattern == ".": case entry.Pattern == ".":
return true return true
case entry.IsWildcard: case entry.IsWildcard:
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".") return strings.HasSuffix(qname, "."+entry.Pattern)
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
default: default:
// For non-wildcard patterns: // For non-wildcard patterns:
// If handler wants subdomain matching, allow suffix match // If handler wants subdomain matching, allow suffix match

View File

@@ -164,6 +164,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
matchSubdomains: true, matchSubdomains: true,
shouldMatch: true, shouldMatch: true,
}, },
{
name: "wildcard label-boundary mismatch (suffix overlap)",
handlerDomain: "*.b.test.",
queryDomain: "x.ab.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "wildcard label-boundary match",
handlerDomain: "*.b.test.",
queryDomain: "x.b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "wildcard multi-label match",
handlerDomain: "*.b.test.",
queryDomain: "x.y.b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "wildcard no match on multi-label apex",
handlerDomain: "*.b.test.",
queryDomain: "b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "wildcard no match on unrelated suffix containment",
handlerDomain: "*.example.com.",
queryDomain: "notexample.com.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "wildcard accepts pattern registered without trailing dot",
handlerDomain: "*.b.test",
queryDomain: "x.b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -273,6 +321,19 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
expectedCalls: 1, expectedCalls: 1,
expectedHandler: 2, // highest priority matching handler should be called expectedHandler: 2, // highest priority matching handler should be called
}, },
{
name: "overlapping wildcard suffixes route to correct handler",
handlers: []struct {
pattern string
priority int
}{
{pattern: "*.b.test.", priority: nbdns.PriorityDNSRoute},
{pattern: "*.ab.test.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "app.ab.test.",
expectedCalls: 1,
expectedHandler: 1,
},
{ {
name: "root zone with specific domain", name: "root zone with specific domain",
handlers: []struct { handlers: []struct {

View File

@@ -26,6 +26,19 @@ type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
} }
// PeerConnectivity reports whether a tunnel IP belongs to a peer the
// client knows about and whether that peer is currently connected. The
// local resolver uses this to suppress A/AAAA answers whose RDATA points
// at a disconnected peer (typical case: a synthesized private-service
// record pointing at an embedded proxy peer that just went offline).
//
// known=false means the IP isn't in the local peerstore at all — the
// record is left alone (it points at something outside our mesh, e.g.
// a non-peer upstream).
type PeerConnectivity interface {
IsConnectedByIP(ip string) (known, connected bool)
}
type Resolver struct { type Resolver struct {
mu sync.RWMutex mu sync.RWMutex
records map[dns.Question][]dns.RR records map[dns.Question][]dns.RR
@@ -33,6 +46,11 @@ type Resolver struct {
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone) // zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
zones map[domain.Domain]bool zones map[domain.Domain]bool
resolver resolver resolver resolver
// peerConn, when non-nil, is consulted on every A/AAAA answer to
// drop records pointing at disconnected peers. nil disables the
// filter and preserves the legacy "return whatever is registered"
// behaviour for callers that never wire a status source.
peerConn PeerConnectivity
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
@@ -49,6 +67,15 @@ func NewResolver() *Resolver {
} }
} }
// SetPeerConnectivity wires the per-IP connectivity check used to filter
// out A/AAAA answers pointing at disconnected peers. Pass nil to disable.
// Safe to call multiple times; the latest value wins.
func (d *Resolver) SetPeerConnectivity(p PeerConnectivity) {
d.mu.Lock()
defer d.mu.Unlock()
d.peerConn = p
}
func (d *Resolver) MatchSubdomains() bool { func (d *Resolver) MatchSubdomains() bool {
return true return true
} }
@@ -95,6 +122,7 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
replyMessage.RecursionAvailable = true replyMessage.RecursionAvailable = true
result := d.lookupRecords(logger, question) result := d.lookupRecords(logger, question)
result.records = d.filterDisconnectedPeerAnswers(logger, question, result.records)
replyMessage.Authoritative = !result.hasExternalData replyMessage.Authoritative = !result.hasExternalData
replyMessage.Answer = result.records replyMessage.Answer = result.records
replyMessage.Rcode = d.determineRcode(question, result) replyMessage.Rcode = d.determineRcode(question, result)
@@ -436,6 +464,78 @@ func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16,
} }
} }
// filterDisconnectedPeerAnswers drops A/AAAA records whose RDATA matches
// a known but disconnected peer. The synthesized private-service zones
// emit one A record per connected proxy peer in a cluster; when a peer
// goes offline, the server-side refresh removes the record from the
// next netmap, but the client may still hold the previous netmap for a
// short window. This filter is the local belt to that braces — even on
// the stale netmap, the resolver hides the offline target.
//
// Records pointing at unknown IPs (outside the local peerstore, e.g.
// non-mesh upstreams) are never dropped. Non-A/AAAA records pass
// through untouched.
//
// Escape hatch: if filtering would leave the answer empty AND at least
// one record was filtered, the original list is returned. Better to
// hand the client a record that may not respond than NXDOMAIN it
// completely when every proxy peer is offline (the upstream may still
// be reachable some other way, or the peerstore may be stale).
func (d *Resolver) filterDisconnectedPeerAnswers(logger *log.Entry, question dns.Question, records []dns.RR) []dns.RR {
if len(records) == 0 {
return records
}
d.mu.RLock()
checker := d.peerConn
d.mu.RUnlock()
if checker == nil {
return records
}
kept := make([]dns.RR, 0, len(records))
var dropped int
for _, rr := range records {
ip := extractRecordIP(rr)
if ip == "" {
kept = append(kept, rr)
continue
}
known, connected := checker.IsConnectedByIP(ip)
if known && !connected {
dropped++
continue
}
kept = append(kept, rr)
}
if dropped == 0 {
return records
}
if len(kept) == 0 {
logger.Debugf("all %d answers for %s point at disconnected peers; returning the original list", dropped, question.Name)
return records
}
logger.Tracef("dropped %d disconnected-peer answer(s) for %s, returning %d", dropped, question.Name, len(kept))
return kept
}
// extractRecordIP returns the dotted-decimal / colon-hex IP carried by
// an A or AAAA record, or "" for any other record type.
func extractRecordIP(rr dns.RR) string {
switch r := rr.(type) {
case *dns.A:
if r.A == nil {
return ""
}
return r.A.String()
case *dns.AAAA:
if r.AAAA == nil {
return ""
}
return r.AAAA.String()
}
return ""
}
// Update replaces all zones and their records // Update replaces all zones and their records
func (d *Resolver) Update(customZones []nbdns.CustomZone) { func (d *Resolver) Update(customZones []nbdns.CustomZone) {
d.mu.Lock() d.mu.Lock()

View File

@@ -30,6 +30,21 @@ func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([
return nil, nil return nil, nil
} }
// mockPeerConnectivity returns canned (known, connected) results per IP.
// Used by the disconnected-peer filter tests below. IPs not in the map
// are reported as unknown so the filter leaves them alone.
type mockPeerConnectivity struct {
byIP map[string]struct{ known, connected bool }
}
func (m mockPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) {
v, ok := m.byIP[ip]
if !ok {
return false, false
}
return v.known, v.connected
}
func TestLocalResolver_ServeDNS(t *testing.T) { func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{ recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.", Name: "peera.netbird.cloud.",
@@ -2652,3 +2667,114 @@ func BenchmarkIsInManagedZone_ManyZones(b *testing.B) {
resolver.isInManagedZone(qname) resolver.isInManagedZone(qname)
} }
} }
// TestLocalResolver_FilterDisconnectedPeerAnswers verifies the
// connectivity-aware filtering layered on top of lookupRecords:
// when an A record's IP belongs to a known peer that's disconnected,
// the record is dropped from the answer. Records for unknown IPs pass
// through. If filtering would empty the answer entirely and at least
// one record was dropped, the original list is restored (escape hatch
// for the "all proxies offline" case).
func TestLocalResolver_FilterDisconnectedPeerAnswers(t *testing.T) {
zone := "svc.cluster.netbird."
connectedRec := nbdns.SimpleRecord{
Name: zone,
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 5,
RData: "100.64.0.10",
}
disconnectedRec := nbdns.SimpleRecord{
Name: zone,
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 5,
RData: "100.64.0.11",
}
unknownRec := nbdns.SimpleRecord{
Name: zone,
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 5,
RData: "203.0.113.5",
}
type ipState struct{ known, connected bool }
tests := []struct {
name string
records []nbdns.SimpleRecord
connByIP map[string]ipState
wantInOrder []string
}{
{
name: "drops disconnected peer, keeps connected",
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
connByIP: map[string]ipState{
"100.64.0.10": {known: true, connected: true},
"100.64.0.11": {known: true, connected: false},
},
wantInOrder: []string{"100.64.0.10"},
},
{
name: "unknown IPs pass through untouched",
records: []nbdns.SimpleRecord{unknownRec, disconnectedRec},
connByIP: map[string]ipState{
"100.64.0.11": {known: true, connected: false},
},
wantInOrder: []string{"203.0.113.5"},
},
{
name: "all disconnected falls back to original list",
records: []nbdns.SimpleRecord{disconnectedRec, connectedRec},
connByIP: map[string]ipState{
"100.64.0.10": {known: true, connected: false},
"100.64.0.11": {known: true, connected: false},
},
wantInOrder: []string{"100.64.0.11", "100.64.0.10"},
},
{
name: "no checker wired returns all records",
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
connByIP: nil,
wantInOrder: []string{"100.64.0.10", "100.64.0.11"},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
resolver := NewResolver()
if tc.connByIP != nil {
cm := mockPeerConnectivity{byIP: make(map[string]struct{ known, connected bool }, len(tc.connByIP))}
for ip, st := range tc.connByIP {
cm.byIP[ip] = struct{ known, connected bool }{st.known, st.connected}
}
resolver.SetPeerConnectivity(cm)
}
resolver.Update([]nbdns.CustomZone{{
Domain: strings.TrimSuffix(zone, "."),
Records: tc.records,
NonAuthoritative: true,
}})
var got *dns.Msg
writer := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
got = m
return nil
},
}
req := new(dns.Msg).SetQuestion(zone, dns.TypeA)
resolver.ServeDNS(writer, req)
require.NotNil(t, got, "resolver must produce a response")
require.Len(t, got.Answer, len(tc.wantInOrder),
"answer count must match expected: %v", tc.wantInOrder)
for i, want := range tc.wantInOrder {
a, ok := got.Answer[i].(*dns.A)
require.True(t, ok, "answer[%d] must be an A record", i)
assert.Equal(t, want, a.A.String(),
"answer[%d] expected %s got %s", i, want, a.A.String())
}
})
}
}

View File

@@ -301,6 +301,11 @@ func newDefaultServer(
warningDelayBase: defaultWarningDelayBase, warningDelayBase: defaultWarningDelayBase,
healthRefresh: make(chan struct{}, 1), healthRefresh: make(chan struct{}, 1),
} }
// Wire the local resolver against the peer status recorder so it can
// suppress A/AAAA answers that point at disconnected peers (typical
// case: synthesised private-service records pointing at an embedded
// proxy peer that just went offline).
defaultServer.localResolver.SetPeerConnectivity(localPeerConnectivity{statusRecorder})
// register with root zone, handler chain takes care of the routing // register with root zone, handler chain takes care of the routing
dnsService.RegisterMux(".", handlerChain) dnsService.RegisterMux(".", handlerChain)
@@ -1386,3 +1391,25 @@ func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error {
} }
return nil return nil
} }
// localPeerConnectivity adapts *peer.Status to local.PeerConnectivity so
// the local resolver can ask "is this IP a known peer and is it
// connected?" without taking on the peer package as a dependency.
// A nil status recorder always reports known=false so the resolver
// short-circuits to the legacy "return everything" path.
type localPeerConnectivity struct {
status *peer.Status
}
// IsConnectedByIP looks the IP up in the peerstore and surfaces both
// the known and connected bits. Used by Resolver.filterDisconnectedPeerAnswers.
func (l localPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) {
if l.status == nil {
return false, false
}
state, ok := l.status.PeerStateByIP(ip)
if !ok {
return false, false
}
return true, state.ConnStatus == peer.StatusConnected
}

View File

@@ -61,9 +61,11 @@ import (
cProto "github.com/netbirdio/netbird/client/proto" cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
types "github.com/netbirdio/netbird/shared/management/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
mgm "github.com/netbirdio/netbird/shared/management/client" mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
nbnetworkmap "github.com/netbirdio/netbird/shared/management/networkmap"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil" "github.com/netbirdio/netbird/shared/netiputil"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac" auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
@@ -202,6 +204,13 @@ type Engine struct {
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service // networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64 networkSerial uint64
// latestComponents is the most-recent NetworkMapComponents decoded from
// a NetworkMapEnvelope (capability=3 peers only). Held alongside the
// NetworkMap that Calculate() produced from it so future incremental
// updates have a base to apply changes against. nil for legacy-format
// peers. Guarded by syncMsgMux.
latestComponents *types.NetworkMapComponents
networkMonitor *networkmonitor.NetworkMonitor networkMonitor *networkmonitor.NetworkMonitor
sshServer sshServer sshServer sshServer
@@ -865,8 +874,12 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return e.ctx.Err() return e.ctx.Err()
} }
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil { // Envelope sync responses carry PeerConfig at the top level; legacy
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate) // NetworkMap syncs carry it under NetworkMap.PeerConfig.
if pc := update.GetPeerConfig(); pc != nil {
e.handleAutoUpdateVersion(pc.GetAutoUpdate())
} else if nm := update.GetNetworkMap(); nm != nil && nm.GetPeerConfig() != nil {
e.handleAutoUpdateVersion(nm.GetPeerConfig().GetAutoUpdate())
} }
if update.GetNetbirdConfig() != nil { if update.GetNetbirdConfig() != nil {
@@ -907,11 +920,45 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return err return err
} }
nm := update.GetNetworkMap() var (
nm *mgmProto.NetworkMap
components *types.NetworkMapComponents
)
if envelope := update.GetNetworkMapEnvelope(); envelope != nil {
// Components-format peer: decode the envelope back to typed
// components, run Calculate() locally, and convert to the wire
// NetworkMap shape the rest of the engine consumes. Components are
// retained so future incremental updates can apply deltas instead
// of doing a full reconstruction.
localKey := e.config.WgPrivateKey.PublicKey().String()
dnsName := ""
if pc := update.GetPeerConfig(); pc != nil {
// PeerConfig.Fqdn = "<dns_label>.<dns_domain>" — extract the
// shared domain by stripping the peer's own label prefix. Falls
// back to empty if the FQDN doesn't have the expected shape.
dnsName = extractDNSDomainFromFQDN(pc.GetFqdn())
}
result, err := nbnetworkmap.EnvelopeToNetworkMap(e.ctx, envelope, localKey, dnsName)
if err != nil {
return fmt.Errorf("decode network map envelope: %w", err)
}
nm = result.NetworkMap
components = result.Components
} else {
nm = update.GetNetworkMap()
}
if nm == nil { if nm == nil {
return nil return nil
} }
// Only retain the components view when the server sent the envelope
// path. A legacy proto.NetworkMap means components == nil; writing it
// here would clobber a previously-cached snapshot, breaking the
// incremental-delta base on a future envelope sync.
if components != nil {
e.latestComponents = components
}
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux. // Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
// Read the storage-enabled flag under the syncRespMux too. // Read the storage-enabled flag under the syncRespMux too.
e.syncRespMux.RLock() e.syncRespMux.RLock()
@@ -937,6 +984,19 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return nil return nil
} }
// extractDNSDomainFromFQDN returns the trailing dotted domain part of the
// receiving peer's FQDN — the same value the management server fills as
// dnsName when it builds the legacy NetworkMap. "peer42.netbird.cloud" →
// "netbird.cloud". An empty string is returned for unrecognized formats.
func extractDNSDomainFromFQDN(fqdn string) string {
for i := 0; i < len(fqdn); i++ {
if fqdn[i] == '.' && i+1 < len(fqdn) {
return fqdn[i+1:]
}
}
return ""
}
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error { func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
if update != nil { if update != nil {
// when we receive token we expect valid address list too // when we receive token we expect valid address list too
@@ -1967,6 +2027,29 @@ func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
return e.clientMetrics return e.clientMetrics
} }
// Performance bundles runtime-adjustable tunnel pool knobs.
// See Engine.SetPerformance. Nil fields are ignored.
type Performance struct {
PreallocatedBuffersPerPool *uint32
}
// SetPerformance applies the given tuning to this engine's live Device.
func (e *Engine) SetPerformance(t Performance) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if e.wgInterface == nil {
return fmt.Errorf("wg interface not initialized")
}
dev := e.wgInterface.GetWGDevice()
if dev == nil {
return fmt.Errorf("wg device not initialized")
}
if t.PreallocatedBuffersPerPool != nil {
dev.SetPreallocatedBuffersPerPool(*t.PreallocatedBuffersPerPool)
}
return nil
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) { func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName) iface, err := net.InterfaceByName(ifaceName)
if err != nil { if err != nil {

View File

@@ -27,7 +27,7 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
@@ -66,8 +66,8 @@ import (
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
mgmt "github.com/netbirdio/netbird/shared/management/client" mgmt "github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto" mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/shared/netiputil" "github.com/netbirdio/netbird/shared/netiputil"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client" signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server" signalServer "github.com/netbirdio/netbird/signal/server"
@@ -1641,7 +1641,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
return nil, "", err return nil, "", err
} }
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore) ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err) require.NoError(t, err)

View File

@@ -50,7 +50,7 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
switch msg.Type { switch msg.Type {
// handle route changes // handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE: case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n]) route, flags, err := parseRouteMessage(buf[:n])
if err != nil { if err != nil {
log.Debugf("Network monitor: error parsing routing message: %v", err) log.Debugf("Network monitor: error parsing routing message: %v", err)
continue continue
@@ -66,6 +66,10 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
} }
switch msg.Type { switch msg.Type {
case unix.RTM_ADD: case unix.RTM_ADD:
if systemops.IgnoreAddedDefaultRoute(flags) {
log.Debugf("Network monitor: ignoring added default route via %s, interface %s, flags %#x", route.Gw, intf, flags)
continue
}
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
return nil return nil
case unix.RTM_DELETE: case unix.RTM_DELETE:
@@ -78,22 +82,26 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
} }
} }
func parseRouteMessage(buf []byte) (*systemops.Route, error) { func parseRouteMessage(buf []byte) (*systemops.Route, int, error) {
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf) msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
if err != nil { if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err) return nil, 0, fmt.Errorf("parse RIB: %v", err)
} }
if len(msgs) != 1 { if len(msgs) != 1 {
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs) return nil, 0, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
} }
msg, ok := msgs[0].(*route.RouteMessage) msg, ok := msgs[0].(*route.RouteMessage)
if !ok { if !ok {
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0]) return nil, 0, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
} }
return systemops.MsgToRoute(msg) r, err := systemops.MsgToRoute(msg)
if err != nil {
return nil, 0, err
}
return r, msg.Flags, nil
} }
// waitReadable blocks until fd has data to read, or ctx is cancelled. // waitReadable blocks until fd has data to read, or ctx is cancelled.

View File

@@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/id" "github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker" "github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/portforward" "github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client" relayClient "github.com/netbirdio/netbird/shared/relay/client"
@@ -899,7 +900,7 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
} }
// Fallback to deterministic key if no NetBird PSK is configured // Fallback to deterministic key if no NetBird PSK is configured
determKey, err := conn.rosenpassDetermKey() determKey, err := rosenpass.DeterministicSeedKey(conn.config.LocalKey, conn.config.Key)
if err != nil { if err != nil {
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err) conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
return nil return nil
@@ -908,26 +909,6 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
return determKey return determKey
} }
// todo: move this logic into Rosenpass package
func (conn *Conn) rosenpassDetermKey() (*wgtypes.Key, error) {
lk := []byte(conn.config.LocalKey)
rk := []byte(conn.config.Key) // remote key
var keyInput []byte
if string(lk) > string(rk) {
//nolint:gocritic
keyInput = append(lk[:16], rk[:16]...)
} else {
//nolint:gocritic
keyInput = append(rk[:16], lk[:16]...)
}
key, err := wgtypes.NewKey(keyInput)
if err != nil {
return nil, err
}
return &key, nil
}
func isController(config ConnConfig) bool { func isController(config ConnConfig) bool {
return config.LocalKey > config.Key return config.LocalKey > config.Key
} }

View File

@@ -185,9 +185,12 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
return s.eventsChan return s.eventsChan
} }
// Status holds a state of peers, signal, management connections and relays // Status holds a state of peers, signal, management connections and relays.
// mux is an RWMutex so hot read paths (notably PeerStateByIP, called for
// every private-service request) don't contend against each other.
// Pure read methods take RLock; anything that mutates state takes Lock.
type Status struct { type Status struct {
mux sync.Mutex mux sync.RWMutex
peers map[string]State peers map[string]State
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
signalState bool signalState bool
@@ -283,8 +286,8 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string)
// GetPeer adds peer to Daemon status map // GetPeer adds peer to Daemon status map
func (d *Status) GetPeer(peerPubKey string) (State, error) { func (d *Status) GetPeer(peerPubKey string) (State, error) {
d.mux.Lock() d.mux.RLock()
defer d.mux.Unlock() defer d.mux.RUnlock()
state, ok := d.peers[peerPubKey] state, ok := d.peers[peerPubKey]
if !ok { if !ok {
@@ -294,8 +297,8 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
} }
func (d *Status) PeerByIP(ip string) (string, bool) { func (d *Status) PeerByIP(ip string) (string, bool) {
d.mux.Lock() d.mux.RLock()
defer d.mux.Unlock() defer d.mux.RUnlock()
for _, state := range d.peers { for _, state := range d.peers {
if state.IP == ip { if state.IP == ip {
@@ -305,6 +308,25 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
return "", false return "", false
} }
// PeerStateByIP returns the full peer State for the given tunnel IP.
// Matches against either the IPv4 (State.IP) or IPv6 (State.IPv6) tunnel
// address so dual-stack peers are reachable on either family. Returns the
// zero State and false when no peer matches or the input is empty.
func (d *Status) PeerStateByIP(ip string) (State, bool) {
if ip == "" {
return State{}, false
}
d.mux.RLock()
defer d.mux.RUnlock()
for _, state := range d.peers {
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
return state, true
}
}
return State{}, false
}
// RemovePeer removes peer from Daemon status map // RemovePeer removes peer from Daemon status map
func (d *Status) RemovePeer(peerPubKey string) error { func (d *Status) RemovePeer(peerPubKey string) error {
d.mux.Lock() d.mux.Lock()
@@ -702,8 +724,8 @@ func (d *Status) UnsubscribePeerStateChanges(subscription *StatusChangeSubscript
// GetLocalPeerState returns the local peer state // GetLocalPeerState returns the local peer state
func (d *Status) GetLocalPeerState() LocalPeerState { func (d *Status) GetLocalPeerState() LocalPeerState {
d.mux.Lock() d.mux.RLock()
defer d.mux.Unlock() defer d.mux.RUnlock()
return d.localPeer.Clone() return d.localPeer.Clone()
} }
@@ -909,8 +931,8 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
} }
func (d *Status) GetRosenpassState() RosenpassState { func (d *Status) GetRosenpassState() RosenpassState {
d.mux.Lock() d.mux.RLock()
defer d.mux.Unlock() defer d.mux.RUnlock()
return RosenpassState{ return RosenpassState{
d.rosenpassEnabled, d.rosenpassEnabled,
d.rosenpassPermissive, d.rosenpassPermissive,
@@ -918,14 +940,14 @@ func (d *Status) GetRosenpassState() RosenpassState {
} }
func (d *Status) GetLazyConnection() bool { func (d *Status) GetLazyConnection() bool {
d.mux.Lock() d.mux.RLock()
defer d.mux.Unlock() defer d.mux.RUnlock()
return d.lazyConnectionEnabled return d.lazyConnectionEnabled
} }
func (d *Status) GetManagementState() ManagementState { func (d *Status) GetManagementState() ManagementState {
d.mux.Lock() d.mux.RLock()
defer d.mux.Unlock() defer d.mux.RUnlock()
return ManagementState{ return ManagementState{
d.mgmAddress, d.mgmAddress,
d.managementState, d.managementState,
@@ -951,8 +973,8 @@ func (d *Status) UpdateLatency(pubKey string, latency time.Duration) error {
// IsLoginRequired determines if a peer's login has expired. // IsLoginRequired determines if a peer's login has expired.
func (d *Status) IsLoginRequired() bool { func (d *Status) IsLoginRequired() bool {
d.mux.Lock() d.mux.RLock()
defer d.mux.Unlock() defer d.mux.RUnlock()
// if peer is connected to the management then login is not expired // if peer is connected to the management then login is not expired
if d.managementState { if d.managementState {
@@ -967,8 +989,8 @@ func (d *Status) IsLoginRequired() bool {
} }
func (d *Status) GetSignalState() SignalState { func (d *Status) GetSignalState() SignalState {
d.mux.Lock() d.mux.RLock()
defer d.mux.Unlock() defer d.mux.RUnlock()
return SignalState{ return SignalState{
d.signalAddress, d.signalAddress,
d.signalState, d.signalState,
@@ -978,8 +1000,8 @@ func (d *Status) GetSignalState() SignalState {
// GetRelayStates returns the stun/turn/permanent relay states // GetRelayStates returns the stun/turn/permanent relay states
func (d *Status) GetRelayStates() []relay.ProbeResult { func (d *Status) GetRelayStates() []relay.ProbeResult {
d.mux.Lock() d.mux.RLock()
defer d.mux.Unlock() defer d.mux.RUnlock()
if d.relayMgr == nil { if d.relayMgr == nil {
return d.relayStates return d.relayStates
} }
@@ -1008,8 +1030,8 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
} }
func (d *Status) ForwardingRules() []firewall.ForwardRule { func (d *Status) ForwardingRules() []firewall.ForwardRule {
d.mux.Lock() d.mux.RLock()
defer d.mux.Unlock() defer d.mux.RUnlock()
if d.ingressGwMgr == nil { if d.ingressGwMgr == nil {
return nil return nil
} }
@@ -1018,16 +1040,16 @@ func (d *Status) ForwardingRules() []firewall.ForwardRule {
} }
func (d *Status) GetDNSStates() []NSGroupState { func (d *Status) GetDNSStates() []NSGroupState {
d.mux.Lock() d.mux.RLock()
defer d.mux.Unlock() defer d.mux.RUnlock()
// shallow copy is good enough, as slices fields are currently not updated // shallow copy is good enough, as slices fields are currently not updated
return slices.Clone(d.nsGroupStates) return slices.Clone(d.nsGroupStates)
} }
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo { func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
d.mux.Lock() d.mux.RLock()
defer d.mux.Unlock() defer d.mux.RUnlock()
return maps.Clone(d.resolvedDomainsStates) return maps.Clone(d.resolvedDomainsStates)
} }
@@ -1043,8 +1065,8 @@ func (d *Status) GetFullStatus() FullStatus {
LazyConnectionEnabled: d.GetLazyConnection(), LazyConnectionEnabled: d.GetLazyConnection(),
} }
d.mux.Lock() d.mux.RLock()
defer d.mux.Unlock() defer d.mux.RUnlock()
fullStatus.LocalPeerState = d.localPeer fullStatus.LocalPeerState = d.localPeer
@@ -1219,8 +1241,8 @@ func (d *Status) SetWgIface(wgInterface WGIfaceStatus) {
} }
func (d *Status) PeersStatus() (*configurer.Stats, error) { func (d *Status) PeersStatus() (*configurer.Stats, error) {
d.mux.Lock() d.mux.RLock()
defer d.mux.Unlock() defer d.mux.RUnlock()
if d.wgIface == nil { if d.wgIface == nil {
return nil, fmt.Errorf("wgInterface is nil, cannot retrieve peers status") return nil, fmt.Errorf("wgInterface is nil, cannot retrieve peers status")
} }

View File

@@ -63,6 +63,33 @@ func TestUpdatePeerState(t *testing.T) {
assert.Equal(t, ip, state.IP, "ip should be equal") assert.Equal(t, ip, state.IP, "ip should be equal")
} }
func TestStatus_PeerStateByIP(t *testing.T) {
status := NewRecorder("https://mgm")
req := require.New(t)
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", ""))
req.NoError(status.AddPeer("pk-2", "peer-2.netbird", "100.64.0.11", ""))
state, ok := status.PeerStateByIP("100.64.0.10")
req.True(ok, "known tunnel IP should resolve to a peer state")
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
req.Equal("peer-1.netbird", state.FQDN, "matching state must carry the right FQDN")
_, ok = status.PeerStateByIP("100.64.0.99")
req.False(ok, "unknown IP must report ok=false")
}
func TestStatus_PeerStateByIP_MatchesIPv6(t *testing.T) {
status := NewRecorder("https://mgm")
req := require.New(t)
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", "fd00::1"))
state, ok := status.PeerStateByIP("fd00::1")
req.True(ok, "IPv6-only match must resolve to the peer state")
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
}
func TestStatus_UpdatePeerFQDN(t *testing.T) { func TestStatus_UpdatePeerFQDN(t *testing.T) {
key := "abc" key := "abc"
fqdn := "peer-a.netbird.local" fqdn := "peer-a.netbird.local"

View File

@@ -28,6 +28,15 @@ func hashRosenpassKey(key []byte) string {
return hex.EncodeToString(hasher.Sum(nil)) return hex.EncodeToString(hasher.Sum(nil))
} }
// rpServer is the subset of rp.Server used by Manager. Defined as an interface
// so tests can substitute a mock without spinning up a real UDP server.
type rpServer interface {
AddPeer(rp.PeerConfig) (rp.PeerID, error)
RemovePeer(rp.PeerID) error
Run() error
Close() error
}
type Manager struct { type Manager struct {
ifaceName string ifaceName string
spk []byte spk []byte
@@ -36,7 +45,7 @@ type Manager struct {
preSharedKey *[32]byte preSharedKey *[32]byte
rpPeerIDs map[string]*rp.PeerID rpPeerIDs map[string]*rp.PeerID
rpWgHandler *NetbirdHandler rpWgHandler *NetbirdHandler
server *rp.Server server rpServer
lock sync.Mutex lock sync.Mutex
port int port int
wgIface PresharedKeySetter wgIface PresharedKeySetter
@@ -51,7 +60,22 @@ func NewManager(preSharedKey *wgtypes.Key, wgIfaceName string) (*Manager, error)
rpKeyHash := hashRosenpassKey(public) rpKeyHash := hashRosenpassKey(public)
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash) log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
return &Manager{ifaceName: wgIfaceName, rpKeyHash: rpKeyHash, spk: public, ssk: secret, preSharedKey: (*[32]byte)(preSharedKey), rpPeerIDs: make(map[string]*rp.PeerID), lock: sync.Mutex{}}, nil return &Manager{
ifaceName: wgIfaceName,
rpKeyHash: rpKeyHash,
spk: public,
ssk: secret,
preSharedKey: (*[32]byte)(preSharedKey),
rpPeerIDs: make(map[string]*rp.PeerID),
// rpWgHandler is created here (instead of only in generateConfig) so it
// is never nil between NewManager and Run(). Otherwise an early
// OnConnected call (race observed on Android, issue #4341) panics on
// nil receiver in addPeer -> m.rpWgHandler.AddPeer. generateConfig will
// replace it with a fresh handler on each Run() to clear stale peer
// state from previous engine sessions.
rpWgHandler: NewNetbirdHandler(),
lock: sync.Mutex{},
}, nil
} }
func (m *Manager) GetPubKey() []byte { func (m *Manager) GetPubKey() []byte {
@@ -65,6 +89,16 @@ func (m *Manager) GetAddress() *net.UDPAddr {
// addPeer adds a new peer to the Rosenpass server // addPeer adds a new peer to the Rosenpass server
func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuardIP string, wireGuardPubKey string) error { func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuardIP string, wireGuardPubKey string) error {
// Defense in depth against issue #4341 (Android crash): if Run() has not
// completed yet, m.server / m.rpWgHandler may be nil. Return an explicit
// error instead of panicking on nil-receiver dereference.
if m.server == nil {
return fmt.Errorf("rosenpass server not initialized")
}
if m.rpWgHandler == nil {
return fmt.Errorf("rosenpass wg handler not initialized")
}
var err error var err error
pcfg := rp.PeerConfig{PublicKey: rosenpassPubKey} pcfg := rp.PeerConfig{PublicKey: rosenpassPubKey}
if m.preSharedKey != nil { if m.preSharedKey != nil {
@@ -79,6 +113,16 @@ func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuar
if pcfg.Endpoint, err = net.ResolveUDPAddr("udp", peerAddr); err != nil { if pcfg.Endpoint, err = net.ResolveUDPAddr("udp", peerAddr); err != nil {
return fmt.Errorf("failed to resolve peer endpoint address: %w", err) return fmt.Errorf("failed to resolve peer endpoint address: %w", err)
} }
// Our local Rosenpass UDP server binds on the IPv6 wildcard ([::]) — see
// GetAddress(). The remote peer's endpoint (pcfg.Endpoint) is the destination
// our server will sendto when initiating handshakes. ResolveUDPAddr returns a
// 4-byte IPv4 for IPv4 hosts, which the kernel rejects (EDESTADDRREQ) when
// sent from an AF_INET6 socket. Normalize the remote endpoint to IPv4-mapped
// IPv6 so its address family matches our listening socket.
// TODO: maybe bind the Rosenpass UDP server to the peer wg IP addr
if v4 := pcfg.Endpoint.IP.To4(); v4 != nil {
pcfg.Endpoint.IP = v4.To16()
}
} }
peerID, err := m.server.AddPeer(pcfg) peerID, err := m.server.AddPeer(pcfg)
if err != nil { if err != nil {
@@ -182,24 +226,31 @@ func (m *Manager) Run() error {
return err return err
} }
m.server, err = rp.NewUDPServer(conf) server, err := rp.NewUDPServer(conf)
if err != nil { if err != nil {
return err return err
} }
m.lock.Lock()
m.server = server
m.lock.Unlock()
log.Infof("starting rosenpass server on port %d", m.port) log.Infof("starting rosenpass server on port %d", m.port)
return m.server.Run() return server.Run()
} }
// Close closes the Rosenpass server // Close closes the Rosenpass server
func (m *Manager) Close() error { func (m *Manager) Close() error {
if m.server != nil { m.lock.Lock()
err := m.server.Close() server := m.server
if err != nil { m.server = nil
log.Errorf("failed closing local rosenpass server") m.lock.Unlock()
} if server == nil {
m.server = nil return nil
}
if err := server.Close(); err != nil {
log.Errorf("failed closing local rosenpass server: %v", err)
} }
return nil return nil
} }

View File

@@ -1,14 +1,412 @@
package rosenpass package rosenpass
import ( import (
"errors"
"os"
"sync"
"testing" "testing"
rp "cunicu.li/go-rosenpass"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// --- test doubles -----------------------------------------------------------
type addPeerCall struct {
cfg rp.PeerConfig
}
type removePeerCall struct {
id rp.PeerID
}
type mockServer struct {
mu sync.Mutex
addCalls []addPeerCall
removed []removePeerCall
nextID rp.PeerID
addErr error
removeErr error
closed bool
ran bool
}
func (m *mockServer) AddPeer(cfg rp.PeerConfig) (rp.PeerID, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.addCalls = append(m.addCalls, addPeerCall{cfg: cfg})
if m.addErr != nil {
return rp.PeerID{}, m.addErr
}
// Increment a byte in nextID so distinct peers get distinct IDs.
m.nextID[0]++
return m.nextID, nil
}
func (m *mockServer) RemovePeer(id rp.PeerID) error {
m.mu.Lock()
defer m.mu.Unlock()
m.removed = append(m.removed, removePeerCall{id: id})
return m.removeErr
}
func (m *mockServer) Run() error { m.ran = true; return nil }
func (m *mockServer) Close() error { m.closed = true; return nil }
type setPSKCall struct {
peerKey string
psk wgtypes.Key
updateOnly bool
}
type mockIface struct {
mu sync.Mutex
calls []setPSKCall
err error
}
func (m *mockIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
m.mu.Lock()
defer m.mu.Unlock()
m.calls = append(m.calls, setPSKCall{peerKey: peerKey, psk: psk, updateOnly: updateOnly})
return m.err
}
// newTestManager builds a Manager with deterministic spk so tie-break
// against a peer pubkey is controllable from tests. The provided spk byte
// becomes the first byte; remaining bytes are zero.
func newTestManager(spkFirstByte byte, mock *mockServer) *Manager {
spk := make([]byte, 32)
spk[0] = spkFirstByte
return &Manager{
ifaceName: "wt0",
spk: spk,
ssk: make([]byte, 32),
rpKeyHash: "test-hash",
rpPeerIDs: make(map[string]*rp.PeerID),
rpWgHandler: NewNetbirdHandler(),
server: mock,
}
}
// validWGKey returns a deterministic 32-byte wireguard public key (base64).
func validWGKey(t *testing.T, lastByte byte) string {
t.Helper()
var k wgtypes.Key
k[31] = lastByte
return k.String()
}
// --- pure helpers ----------------------------------------------------------
func TestHashRosenpassKey_Deterministic(t *testing.T) {
key := []byte("hello-rosenpass")
require.Equal(t, hashRosenpassKey(key), hashRosenpassKey(key))
require.Len(t, hashRosenpassKey(key), 64) // sha256 hex
}
func TestHashRosenpassKey_DifferentInputsDifferOutputs(t *testing.T) {
require.NotEqual(t, hashRosenpassKey([]byte("a")), hashRosenpassKey([]byte("b")))
}
func TestGetLogLevel_DefaultWhenUnset(t *testing.T) {
// Snapshot + unset to exercise the LookupEnv ok=false branch. t.Setenv
// can only set, not delete, so do it manually with restore via t.Cleanup.
prev, hadPrev := os.LookupEnv(defaultLogLevelVar)
require.NoError(t, os.Unsetenv(defaultLogLevelVar))
t.Cleanup(func() {
if hadPrev {
_ = os.Setenv(defaultLogLevelVar, prev)
} else {
_ = os.Unsetenv(defaultLogLevelVar)
}
})
require.Equal(t, defaultLog.String(), getLogLevel().String())
}
func TestGetLogLevel_Cases(t *testing.T) {
cases := map[string]string{
"debug": "DEBUG",
"info": "INFO",
"warn": "WARN",
"error": "ERROR",
"unknown": "INFO", // default fallback
}
for input, wantStr := range cases {
input, wantStr := input, wantStr
t.Run(input, func(t *testing.T) {
t.Setenv(defaultLogLevelVar, input)
require.Equal(t, wantStr, getLogLevel().String())
})
}
}
func TestFindRandomAvailableUDPPort(t *testing.T) { func TestFindRandomAvailableUDPPort(t *testing.T) {
port, err := findRandomAvailableUDPPort() port, err := findRandomAvailableUDPPort()
require.NoError(t, err) require.NoError(t, err)
require.Greater(t, port, 0) require.Greater(t, port, 0)
require.LessOrEqual(t, port, 65535) require.LessOrEqual(t, port, 65535)
} }
// --- addPeer ---------------------------------------------------------------
func TestAddPeer_HigherLocalPubkey_SetsEndpoint(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv) // local spk lexicographically larger
remotePubKey := make([]byte, 32) // remote spk = all zeros (smaller)
err := m.addPeer(remotePubKey, "rosenpass-host:7000", "100.1.1.1", validWGKey(t, 1))
require.NoError(t, err)
require.Len(t, srv.addCalls, 1)
ep := srv.addCalls[0].cfg.Endpoint
require.NotNil(t, ep, "initiator side must set Endpoint")
require.Equal(t, 7000, ep.Port)
require.Equal(t, "100.1.1.1", ep.IP.String())
}
func TestAddPeer_HigherLocalPubkey_EndpointIPIsIPv4Mapped(t *testing.T) {
// Regression guard for the EDESTADDRREQ fix: Endpoint.IP must be 16-byte
// (IPv4-mapped IPv6) so it matches the AF_INET6 listening socket family.
srv := &mockServer{}
m := newTestManager(0xFF, srv)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.NoError(t, err)
ep := srv.addCalls[0].cfg.Endpoint
require.NotNil(t, ep)
require.Len(t, ep.IP, 16, "IPv4 endpoint must be normalized to 16-byte v4-mapped form")
require.True(t, ep.IP.To4() != nil, "Endpoint must still be detected as IPv4")
}
func TestAddPeer_LowerLocalPubkey_LeavesEndpointNil(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0x00, srv) // local spk smaller
remotePubKey := make([]byte, 32)
remotePubKey[0] = 0xFF
err := m.addPeer(remotePubKey, "rp:5000", "100.1.1.1", validWGKey(t, 2))
require.NoError(t, err)
require.Nil(t, srv.addCalls[0].cfg.Endpoint, "responder side must NOT set Endpoint")
}
func TestAddPeer_PresharedKeyPropagated(t *testing.T) {
srv := &mockServer{}
psk := &wgtypes.Key{0x42}
m := newTestManager(0xFF, srv)
m.preSharedKey = (*[32]byte)(psk)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 3))
require.NoError(t, err)
require.Equal(t, [32]byte(*psk), [32]byte(srv.addCalls[0].cfg.PresharedKey))
}
func TestAddPeer_InvalidRosenpassAddr_ReturnsError(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv) // initiator path → parses rosenpassAddr
err := m.addPeer(make([]byte, 32), "not-a-host-port", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
require.Empty(t, srv.addCalls, "server.AddPeer must not run when address parse fails")
}
func TestAddPeer_InvalidWireGuardPubKey_ReturnsError(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", "not-a-valid-key")
require.Error(t, err)
}
func TestAddPeer_ServerError_Propagates(t *testing.T) {
srv := &mockServer{addErr: errors.New("boom")}
m := newTestManager(0xFF, srv)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
}
// Regression guard for issue #4341 (Android crash). If Run() has not completed
// before OnConnected fires, m.rpWgHandler or m.server may be nil. Without the
// nil guards, m.rpWgHandler.AddPeer panics on nil receiver.
func TestAddPeer_NilHandler_ReturnsErrorNoCrash(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
m.rpWgHandler = nil // simulate Run() not yet completed
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
require.Contains(t, err.Error(), "wg handler not initialized")
}
func TestAddPeer_NilServer_ReturnsErrorNoCrash(t *testing.T) {
m := newTestManager(0xFF, nil)
m.server = nil // simulate Run() not yet completed
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
require.Contains(t, err.Error(), "server not initialized")
}
// NewManager must pre-initialize rpWgHandler so the nil-receiver crash from
// issue #4341 cannot occur in the window between NewManager and Run().
func TestNewManager_PreInitializesHandler(t *testing.T) {
psk := wgtypes.Key{}
m, err := NewManager(&psk, "wt0")
require.NoError(t, err)
require.NotNil(t, m.rpWgHandler, "rpWgHandler must be initialized in NewManager")
}
func TestAddPeer_RecordsPeerID(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 5)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey)
require.NoError(t, err)
require.Contains(t, m.rpPeerIDs, wgKey)
}
// --- OnConnected / OnDisconnected ------------------------------------------
func TestOnConnected_NilRemotePubKey_NoAddPeer(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
m.OnConnected(validWGKey(t, 1), nil, "100.1.1.1", "rp:5000")
require.Empty(t, srv.addCalls, "nil remote rosenpass pubkey must skip AddPeer")
require.Empty(t, m.rpPeerIDs)
}
func TestOnConnected_ValidPubKey_CallsAddPeer(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 1)
m.OnConnected(wgKey, make([]byte, 32), "100.1.1.1", "rp:5000")
require.Len(t, srv.addCalls, 1)
require.Contains(t, m.rpPeerIDs, wgKey)
}
func TestOnDisconnected_UnknownPeer_NoOp(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
m.OnDisconnected(validWGKey(t, 99))
require.Empty(t, srv.removed, "unknown peer key must not call RemovePeer")
}
func TestOnDisconnected_KnownPeer_CallsRemoveAndForgets(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 1)
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
require.Contains(t, m.rpPeerIDs, wgKey)
m.OnDisconnected(wgKey)
require.Len(t, srv.removed, 1)
require.NotContains(t, m.rpPeerIDs, wgKey, "peer must be forgotten after disconnect")
}
// --- IsPresharedKeyInitialized ---------------------------------------------
func TestIsPresharedKeyInitialized_UnknownPeer_ReturnsFalse(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
require.False(t, m.IsPresharedKeyInitialized(validWGKey(t, 1)))
}
func TestIsPresharedKeyInitialized_AddedButNotHandshaken_ReturnsFalse(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 2)
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
require.False(t, m.IsPresharedKeyInitialized(wgKey))
}
// --- NetbirdHandler.outputKey ----------------------------------------------
func TestHandler_OutputKey_FirstCallUsesUpdateOnlyFalse(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
pid := rp.PeerID{0x01}
wgKey := wgtypes.Key{0xAA}
h.AddPeer(pid, "wt0", rp.Key(wgKey))
psk := rp.Key{0xBB}
h.HandshakeCompleted(pid, psk)
require.Len(t, iface.calls, 1)
require.False(t, iface.calls[0].updateOnly, "first PSK rotation must use updateOnly=false")
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
}
func TestHandler_OutputKey_SubsequentCallsUseUpdateOnlyTrue(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
pid := rp.PeerID{0x02}
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xCC}))
h.HandshakeCompleted(pid, rp.Key{0x01}) // first
h.HandshakeCompleted(pid, rp.Key{0x02}) // second
require.Len(t, iface.calls, 2)
require.False(t, iface.calls[0].updateOnly)
require.True(t, iface.calls[1].updateOnly, "subsequent rotations must use updateOnly=true")
}
func TestHandler_OutputKey_NilInterface_NoCrashNoCall(t *testing.T) {
h := NewNetbirdHandler()
// no SetInterface — iface remains nil
pid := rp.PeerID{0x03}
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{}))
// Must not panic.
h.HandshakeCompleted(pid, rp.Key{})
}
func TestHandler_OutputKey_UnknownPeer_NoCall(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
h.HandshakeCompleted(rp.PeerID{0xFF}, rp.Key{})
require.Empty(t, iface.calls, "unknown peer id must not trigger SetPresharedKey")
}
func TestHandler_RemovePeer_ClearsInitializedState(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
pid := rp.PeerID{0x04}
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xDD}))
h.HandshakeCompleted(pid, rp.Key{0x01})
require.True(t, h.IsPeerInitialized(pid))
h.RemovePeer(pid)
require.False(t, h.IsPeerInitialized(pid), "RemovePeer must clear initialized flag")
}
func TestHandler_SetInterfaceAfterAddPeer_StillReceivesKey(t *testing.T) {
h := NewNetbirdHandler()
pid := rp.PeerID{0x05}
wgKey := wgtypes.Key{0xEE}
h.AddPeer(pid, "wt0", rp.Key(wgKey))
iface := &mockIface{}
h.SetInterface(iface) // set after AddPeer
h.HandshakeCompleted(pid, rp.Key{0x42})
require.Len(t, iface.calls, 1)
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
}

View File

@@ -0,0 +1,42 @@
package rosenpass
import (
"fmt"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// DeterministicSeedKey derives a 32-byte WireGuard preshared key from a pair
// of peer public keys. Both peers, given the same key pair, produce the same
// output regardless of which side runs the function: the inputs are ordered
// lexicographically before concatenation.
//
// NetBird uses this value as the initial Rosenpass-side preshared key when no
// explicit account-level PSK is configured, so both peers converge on the same
// PSK before the first post-quantum handshake completes.
//
// The resulting key MUST NOT be treated as quantum-safe: it is deterministic
// from public keys and exists only to seed WireGuard until Rosenpass rotates
// in a real post-quantum PSK.
func DeterministicSeedKey(localKey, remoteKey string) (*wgtypes.Key, error) {
lk := []byte(localKey)
rk := []byte(remoteKey)
if len(lk) < 16 || len(rk) < 16 {
return nil, fmt.Errorf("rosenpass: peer keys must be at least 16 bytes (got local=%d, remote=%d)", len(lk), len(rk))
}
var keyInput []byte
if localKey > remoteKey {
keyInput = append(keyInput, lk[:16]...)
keyInput = append(keyInput, rk[:16]...)
} else {
keyInput = append(keyInput, rk[:16]...)
keyInput = append(keyInput, lk[:16]...)
}
key, err := wgtypes.NewKey(keyInput)
if err != nil {
return nil, fmt.Errorf("rosenpass: deterministic seed key: %w", err)
}
return &key, nil
}

View File

@@ -0,0 +1,44 @@
package rosenpass
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestDeterministicSeedKey_SameForBothSides(t *testing.T) {
// Peer A and peer B must derive the same PSK regardless of which side
// computes it: the function orders inputs internally.
a := strings.Repeat("a", 32)
b := strings.Repeat("b", 32)
keyAB, err := DeterministicSeedKey(a, b)
require.NoError(t, err)
keyBA, err := DeterministicSeedKey(b, a)
require.NoError(t, err)
require.Equal(t, keyAB.String(), keyBA.String(), "swapping arguments must yield identical key")
}
func TestDeterministicSeedKey_ChangesWithKeys(t *testing.T) {
a := strings.Repeat("a", 32)
b := strings.Repeat("b", 32)
c := strings.Repeat("c", 32)
keyAB, err := DeterministicSeedKey(a, b)
require.NoError(t, err)
keyAC, err := DeterministicSeedKey(a, c)
require.NoError(t, err)
require.NotEqual(t, keyAB.String(), keyAC.String(), "different peer pair must yield different key")
}
func TestDeterministicSeedKey_TooShortKey_ReturnsError(t *testing.T) {
short := "short" // < 16 bytes
long := strings.Repeat("x", 32)
_, err := DeterministicSeedKey(short, long)
require.Error(t, err)
_, err = DeterministicSeedKey(long, short)
require.Error(t, err)
}

View File

@@ -2,109 +2,54 @@ package ipfwdstate
import ( import (
"fmt" "fmt"
"sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
) )
// IPForwardingState tracks v4 and v6 IP-forwarding sysctl enables with // IPForwardingState is a struct that keeps track of the IP forwarding state.
// independent refcounts so a v4-only routing setup doesn't flip v6 sysctls. // todo: read initial state of the IP forwarding from the system and reset the state based on it.
// todo: separate v4/v6 forwarding state, since the sysctls are independent
// (net.ipv4.ip_forward vs net.ipv6.conf.all.forwarding). Currently the nftables
// manager shares one instance between both routers, which works only because
// EnableIPForwarding enables both sysctls in a single call.
type IPForwardingState struct { type IPForwardingState struct {
mu sync.Mutex enabledCounter int
v4Count int
v6Count int
wgIfaceName string
v6Saved map[string]int
} }
func NewIPForwardingState(wgIfaceName string) *IPForwardingState { func NewIPForwardingState() *IPForwardingState {
return &IPForwardingState{wgIfaceName: wgIfaceName} return &IPForwardingState{}
} }
// Counts returns the current v4 and v6 refcounts. Intended for diagnostics func (f *IPForwardingState) RequestForwarding() error {
// and tests. if f.enabledCounter != 0 {
func (f *IPForwardingState) Counts() (v4, v6 int) { f.enabledCounter++
f.mu.Lock()
defer f.mu.Unlock()
return f.v4Count, f.v6Count
}
// RequestForwarding enables the family's forwarding sysctl on first request.
func (f *IPForwardingState) RequestForwarding(v6 bool) error {
f.mu.Lock()
defer f.mu.Unlock()
if v6 {
return f.requestV6()
}
return f.requestV4()
}
// ReleaseForwarding decrements the family counter. The last v6 release restores
// what enable captured. v4 stays on: net.ipv4.ip_forward is co-owned by other
// tooling (docker, k8s, libvirt).
func (f *IPForwardingState) ReleaseForwarding(v6 bool) error {
f.mu.Lock()
defer f.mu.Unlock()
if v6 {
return f.releaseV6()
}
f.releaseV4()
return nil
}
func (f *IPForwardingState) requestV4() error {
if f.v4Count == 0 {
if err := systemops.EnableV4IPForwarding(); err != nil {
return fmt.Errorf("enable IPv4 forwarding: %w", err)
}
log.Info("IPv4 forwarding enabled")
}
f.v4Count++
return nil
}
func (f *IPForwardingState) releaseV4() {
if f.v4Count > 0 {
f.v4Count--
}
}
func (f *IPForwardingState) requestV6() error {
if f.v6Count == 0 {
saved, err := systemops.EnableV6IPForwarding(f.wgIfaceName)
if err != nil {
if rerr := systemops.DisableV6IPForwarding(saved); rerr != nil {
log.Warnf("rollback partial v6 sysctls: %v", rerr)
}
return fmt.Errorf("enable IPv6 forwarding: %w", err)
}
f.v6Saved = saved
log.Info("IPv6 forwarding enabled")
}
f.v6Count++
return nil
}
func (f *IPForwardingState) releaseV6() error {
if f.v6Count == 0 {
return nil
}
f.v6Count--
if f.v6Count > 0 {
return nil return nil
} }
saved := f.v6Saved if err := systemops.EnableIPForwarding(); err != nil {
f.v6Saved = nil return fmt.Errorf("failed to enable IP forwarding with sysctl: %w", err)
if err := systemops.DisableV6IPForwarding(saved); err != nil {
return fmt.Errorf("disable IPv6 forwarding: %w", err)
} }
log.Info("IPv6 forwarding disabled") f.enabledCounter = 1
log.Info("IP forwarding enabled")
return nil
}
func (f *IPForwardingState) ReleaseForwarding() error {
if f.enabledCounter == 0 {
return nil
}
if f.enabledCounter > 1 {
f.enabledCounter--
return nil
}
// if failed to disable IP forwarding we anyway decrement the counter
f.enabledCounter = 0
// todo call systemops.DisableIPForwarding()
return nil return nil
} }

View File

@@ -0,0 +1,9 @@
//go:build dragonfly || freebsd || netbsd || openbsd
package systemops
// IgnoreAddedDefaultRoute reports whether an RTM_ADD default route with the
// given flags should be ignored by the network monitor.
func IgnoreAddedDefaultRoute(flags int) bool {
return filterRoutesByFlags(flags)
}

View File

@@ -0,0 +1,21 @@
//go:build darwin
package systemops
import "golang.org/x/sys/unix"
// IgnoreAddedDefaultRoute reports whether an RTM_ADD default route with the
// given flags should be ignored by the network monitor. Scoped routes
// (RTF_IFSCOPE) are tied to a specific interface index and cannot replace the
// unscoped default the kernel uses for general egress, so flapping ones (e.g.
// Wi-Fi calling IMS tunnels on ipsec0, Docker bridges, scoped utun defaults)
// must not trigger an engine restart.
func IgnoreAddedDefaultRoute(flags int) bool {
if filterRoutesByFlags(flags) {
return true
}
if flags&unix.RTF_IFSCOPE != 0 {
return true
}
return false
}

View File

@@ -32,17 +32,8 @@ func (r *SysOps) removeFromRouteTable(netip.Prefix, Nexthop) error {
return nil return nil
} }
func EnableV4IPForwarding() error { func EnableIPForwarding() error {
log.Infof("Enable IPv4 forwarding is not implemented on %s", runtime.GOOS) log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func EnableV6IPForwarding(string) (map[string]int, error) {
log.Infof("Enable IPv6 forwarding is not implemented on %s", runtime.GOOS)
return map[string]int{}, nil
}
func DisableV6IPForwarding(map[string]int) error {
return nil return nil
} }

View File

@@ -58,17 +58,8 @@ func (r *SysOps) removeFromRouteTable(netip.Prefix, Nexthop) error {
return nil return nil
} }
func EnableV4IPForwarding() error { func EnableIPForwarding() error {
log.Infof("Enable IPv4 forwarding is not implemented on %s", runtime.GOOS) log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func EnableV6IPForwarding(string) (map[string]int, error) {
log.Infof("Enable IPv6 forwarding is not implemented on %s", runtime.GOOS)
return map[string]int{}, nil
}
func DisableV6IPForwarding(map[string]int) error {
return nil return nil
} }

View File

@@ -763,10 +763,13 @@ func flushRoutes(tableID, family int) error {
return nberrors.FormatErrorOrNil(result) return nberrors.FormatErrorOrNil(result)
} }
func EnableV4IPForwarding() error { func EnableIPForwarding() error {
if _, err := sysctl.Set(ipv4ForwardingPath, 1, false); err != nil { if _, err := sysctl.Set(ipv4ForwardingPath, 1, false); err != nil {
return err return err
} }
if _, err := sysctl.Set(ipv6ForwardingPath, 1, false); err != nil {
log.Warnf("failed to enable IPv6 forwarding: %v", err)
}
return nil return nil
} }

View File

@@ -43,17 +43,8 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error
return r.genericRemoveVPNRoute(prefix, intf) return r.genericRemoveVPNRoute(prefix, intf)
} }
func EnableV4IPForwarding() error { func EnableIPForwarding() error {
log.Infof("Enable IPv4 forwarding is not implemented on %s", runtime.GOOS) log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func EnableV6IPForwarding(string) (map[string]int, error) {
log.Infof("Enable IPv6 forwarding is not implemented on %s", runtime.GOOS)
return map[string]int{}, nil
}
func DisableV6IPForwarding(map[string]int) error {
return nil return nil
} }

View File

@@ -1,82 +0,0 @@
//go:build !android
package systemops
import (
"fmt"
"net"
"os"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
)
const (
// 1 (default) accepts RAs only while forwarding is off; 2 keeps RA
// acceptance on regardless, so RA-installed host defaults survive our
// v6 forwarding flip.
acceptRAInterfacePath = "net.ipv6.conf.%s.accept_ra"
acceptRAProcPathFormat = "/proc/sys/net/ipv6/conf/%s/accept_ra"
)
// EnableV6IPForwarding bumps accept_ra=2 on host v6 interfaces before flipping
// forwarding=1, so RA-installed host defaults survive. Returns the prior values
// of sysctls we actually changed; entries already at the target are omitted.
func EnableV6IPForwarding(wgIfaceName string) (map[string]int, error) {
saved := map[string]int{}
bumpAcceptRA(saved, wgIfaceName)
oldVal, err := sysctl.Set(ipv6ForwardingPath, 1, false)
if err != nil {
return saved, err
}
if oldVal != 1 {
saved[ipv6ForwardingPath] = oldVal
}
return saved, nil
}
// DisableV6IPForwarding restores what EnableV6IPForwarding captured.
func DisableV6IPForwarding(saved map[string]int) error {
var result *multierror.Error
for key, value := range saved {
if _, err := sysctl.Set(key, value, false); err != nil {
result = multierror.Append(result, fmt.Errorf("restore %s: %w", key, err))
}
}
return nberrors.FormatErrorOrNil(result)
}
func bumpAcceptRA(saved map[string]int, wgIfaceName string) {
interfaces, err := net.Interfaces()
if err != nil {
log.Warnf("list interfaces for accept_ra: %v", err)
return
}
for _, intf := range interfaces {
if intf.Name == "lo" || intf.Name == wgIfaceName {
continue
}
bumpAcceptRAForInterface(saved, intf.Name)
}
}
func bumpAcceptRAForInterface(saved map[string]int, name string) {
key := fmt.Sprintf(acceptRAInterfacePath, name)
// Build procfs path from name, not the dotted key: VLAN names like eth0.100.
if _, err := os.Stat(fmt.Sprintf(acceptRAProcPathFormat, name)); err != nil {
return
}
// onlyIfOne=true: leave admin overrides (0, 2) alone.
oldVal, err := sysctl.Set(key, 2, true)
if err != nil {
log.Warnf("bump %s: %v", key, err)
return
}
if oldVal != 2 {
saved[key] = oldVal
}
}

View File

@@ -96,17 +96,19 @@ func (m *Manager) Stop(ctx context.Context) error {
} }
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() cancel := m.cancel
done := m.done
m.mu.Unlock()
if m.cancel == nil { if cancel == nil {
return nil return nil
} }
m.cancel() cancel()
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case <-m.done: case <-done:
} }
return nil return nil

View File

@@ -13,7 +13,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
@@ -315,7 +315,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
return nil, "", err return nil, "", err
} }
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore) ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err) require.NoError(t, err)

View File

@@ -3,15 +3,14 @@
package system package system
import ( import (
"bytes"
"context" "context"
"os" "os"
"os/exec"
"regexp" "regexp"
"runtime" "runtime"
"strings"
"time" "time"
"golang.org/x/sys/unix"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/zcalusic/sysinfo" "github.com/zcalusic/sysinfo"
@@ -29,19 +28,11 @@ func UpdateStaticInfoAsync() {
// GetInfo retrieves and parses the system information // GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info { func GetInfo(ctx context.Context) *Info {
info := _getInfo() kernelName, kernelVersion, kernelPlatform := kernelInfo()
for strings.Contains(info, "broken pipe") {
info = _getInfo()
time.Sleep(500 * time.Millisecond)
}
osStr := strings.ReplaceAll(info, "\n", "")
osStr = strings.ReplaceAll(osStr, "\r\n", "")
osInfo := strings.Split(osStr, " ")
osName, osVersion := readOsReleaseFile() osName, osVersion := readOsReleaseFile()
if osName == "" { if osName == "" {
osName = osInfo[3] osName = kernelName
} }
systemHostname, _ := os.Hostname() systemHostname, _ := os.Hostname()
@@ -58,8 +49,8 @@ func GetInfo(ctx context.Context) *Info {
} }
gio := &Info{ gio := &Info{
Kernel: osInfo[0], Kernel: kernelName,
Platform: osInfo[2], Platform: kernelPlatform,
OS: osName, OS: osName,
OSVersion: osVersion, OSVersion: osVersion,
Hostname: extractDeviceName(ctx, systemHostname), Hostname: extractDeviceName(ctx, systemHostname),
@@ -67,7 +58,7 @@ func GetInfo(ctx context.Context) *Info {
CPUs: runtime.NumCPU(), CPUs: runtime.NumCPU(),
NetbirdVersion: version.NetbirdVersion(), NetbirdVersion: version.NetbirdVersion(),
UIVersion: extractUserAgent(ctx), UIVersion: extractUserAgent(ctx),
KernelVersion: osInfo[1], KernelVersion: kernelVersion,
NetworkAddresses: addrs, NetworkAddresses: addrs,
SystemSerialNumber: si.SystemSerialNumber, SystemSerialNumber: si.SystemSerialNumber,
SystemProductName: si.SystemProductName, SystemProductName: si.SystemProductName,
@@ -78,18 +69,12 @@ func GetInfo(ctx context.Context) *Info {
return gio return gio
} }
func _getInfo() string { func kernelInfo() (string, string, string) {
cmd := exec.Command("uname", "-srio") var uts unix.Utsname
cmd.Stdin = strings.NewReader("some") if err := unix.Uname(&uts); err != nil {
var out bytes.Buffer return "", "", ""
var stderr bytes.Buffer
cmd.Stdout = &out
cmd.Stderr = &stderr
err := cmd.Run()
if err != nil {
log.Warnf("getInfo: %s", err)
} }
return out.String() return unix.ByteSliceToString(uts.Sysname[:]), unix.ByteSliceToString(uts.Release[:]), unix.ByteSliceToString(uts.Machine[:])
} }
func sysInfo() (string, string, string) { func sysInfo() (string, string, string) {

View File

@@ -6,6 +6,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"sync"
"syscall/js" "syscall/js"
"time" "time"
@@ -13,7 +14,7 @@ import (
) )
const ( const (
certValidationTimeout = 60 * time.Second certValidationTimeout = 5 * time.Minute
) )
func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) { func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) {
@@ -46,17 +47,31 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert
promise := conn.wsHandlers.Call("onCertificateRequest", certInfo) promise := conn.wsHandlers.Call("onCertificateRequest", certInfo)
resultChan := make(chan bool) resultChan := make(chan bool, 1)
errorChan := make(chan error) errorChan := make(chan error, 1)
promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} { // Release from inside the callbacks so a post-timeout promise resolution
result := args[0].Bool() // does not invoke an already-released func.
resultChan <- result var thenFn, catchFn js.Func
var releaseOnce sync.Once
release := func() {
releaseOnce.Do(func() {
thenFn.Release()
catchFn.Release()
})
}
thenFn = js.FuncOf(func(this js.Value, args []js.Value) interface{} {
defer release()
resultChan <- args[0].Bool()
return nil return nil
})).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} { })
catchFn = js.FuncOf(func(this js.Value, args []js.Value) interface{} {
defer release()
errorChan <- fmt.Errorf("certificate validation failed") errorChan <- fmt.Errorf("certificate validation failed")
return nil return nil
})) })
promise.Call("then", thenFn).Call("catch", catchFn)
select { select {
case result := <-resultChan: case result := <-resultChan:

View File

@@ -11,6 +11,7 @@ import (
"io" "io"
"net" "net"
"sync" "sync"
"sync/atomic"
"syscall/js" "syscall/js"
"time" "time"
@@ -57,6 +58,8 @@ type RDCleanPathProxy struct {
} }
activeConnections map[string]*proxyConnection activeConnections map[string]*proxyConnection
destinations map[string]string destinations map[string]string
pendingHandlers map[string]js.Func
nextID atomic.Uint64
mu sync.Mutex mu sync.Mutex
} }
@@ -66,8 +69,15 @@ type proxyConnection struct {
rdpConn net.Conn rdpConn net.Conn
tlsConn *tls.Conn tlsConn *tls.Conn
wsHandlers js.Value wsHandlers js.Value
ctx context.Context // Go-side callbacks exposed to JS. js.FuncOf pins the Go closure in a
cancel context.CancelFunc // global handle map and MUST be released, otherwise every connection
// leaks the Go memory the closure captures.
wsHandlerFn js.Func
onMessageFn js.Func
onCloseFn js.Func
cleanupOnce sync.Once
ctx context.Context
cancel context.CancelFunc
} }
// NewRDCleanPathProxy creates a new RDCleanPath proxy // NewRDCleanPathProxy creates a new RDCleanPath proxy
@@ -80,7 +90,11 @@ func NewRDCleanPathProxy(client interface {
} }
} }
// CreateProxy creates a new proxy endpoint for the given destination // CreateProxy creates a new proxy endpoint for the given destination.
// The registered handler fn and its destinations/pendingHandlers entries are
// only released once a connection is established and cleanupConnection runs.
// If a caller invokes CreateProxy but never connects to the returned URL,
// those entries stay pinned for the lifetime of the page.
func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value { func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
destination := net.JoinHostPort(hostname, port) destination := net.JoinHostPort(hostname, port)
@@ -88,7 +102,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
resolve := args[0] resolve := args[0]
go func() { go func() {
proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections)) proxyID := fmt.Sprintf("proxy_%d", p.nextID.Add(1))
p.mu.Lock() p.mu.Lock()
if p.destinations == nil { if p.destinations == nil {
@@ -100,7 +114,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID) proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID)
// Register the WebSocket handler for this specific proxy // Register the WebSocket handler for this specific proxy
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any { handlerFn := js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 1 { if len(args) < 1 {
return js.ValueOf("error: requires WebSocket argument") return js.ValueOf("error: requires WebSocket argument")
} }
@@ -108,7 +122,14 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
ws := args[0] ws := args[0]
p.HandleWebSocketConnection(ws, proxyID) p.HandleWebSocketConnection(ws, proxyID)
return nil return nil
})) })
p.mu.Lock()
if p.pendingHandlers == nil {
p.pendingHandlers = make(map[string]js.Func)
}
p.pendingHandlers[proxyID] = handlerFn
p.mu.Unlock()
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), handlerFn)
log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination) log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination)
resolve.Invoke(proxyURL) resolve.Invoke(proxyURL)
@@ -142,6 +163,10 @@ func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string
p.mu.Lock() p.mu.Lock()
p.activeConnections[proxyID] = conn p.activeConnections[proxyID] = conn
if fn, ok := p.pendingHandlers[proxyID]; ok {
conn.wsHandlerFn = fn
delete(p.pendingHandlers, proxyID)
}
p.mu.Unlock() p.mu.Unlock()
p.setupWebSocketHandlers(ws, conn) p.setupWebSocketHandlers(ws, conn)
@@ -150,7 +175,7 @@ func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string
} }
func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) { func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) {
ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any { conn.onMessageFn = js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 1 { if len(args) < 1 {
return nil return nil
} }
@@ -158,13 +183,15 @@ func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnec
data := args[0] data := args[0]
go p.handleWebSocketMessage(conn, data) go p.handleWebSocketMessage(conn, data)
return nil return nil
})) })
ws.Set("onGoMessage", conn.onMessageFn)
ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any { conn.onCloseFn = js.FuncOf(func(_ js.Value, args []js.Value) any {
log.Debug("WebSocket closed by JavaScript") log.Debug("WebSocket closed by JavaScript")
conn.cancel() conn.cancel()
return nil return nil
})) })
ws.Set("onGoClose", conn.onCloseFn)
} }
func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) { func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) {
@@ -261,25 +288,49 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
} }
func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) { func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) {
log.Debugf("Cleaning up connection %s", conn.id) conn.cleanupOnce.Do(func() {
conn.cancel() log.Debugf("Cleaning up connection %s", conn.id)
if conn.tlsConn != nil { conn.cancel()
log.Debug("Closing TLS connection") if conn.tlsConn != nil {
if err := conn.tlsConn.Close(); err != nil { log.Debug("Closing TLS connection")
log.Debugf("Error closing TLS connection: %v", err) if err := conn.tlsConn.Close(); err != nil {
log.Debugf("Error closing TLS connection: %v", err)
}
conn.tlsConn = nil
} }
conn.tlsConn = nil if conn.rdpConn != nil {
} log.Debug("Closing TCP connection")
if conn.rdpConn != nil { if err := conn.rdpConn.Close(); err != nil {
log.Debug("Closing TCP connection") log.Debugf("Error closing TCP connection: %v", err)
if err := conn.rdpConn.Close(); err != nil { }
log.Debugf("Error closing TCP connection: %v", err) conn.rdpConn = nil
} }
conn.rdpConn = nil js.Global().Delete(fmt.Sprintf("handleRDCleanPathWebSocket_%s", conn.id))
}
p.mu.Lock() // Detach before releasing so late JS calls surface as TypeError instead
delete(p.activeConnections, conn.id) // of silent "call to released function".
p.mu.Unlock() if conn.wsHandlers.Truthy() {
conn.wsHandlers.Set("onGoMessage", js.Undefined())
conn.wsHandlers.Set("onGoClose", js.Undefined())
}
// wsHandlerFn may be zero-value if the pending handler lookup missed.
if conn.wsHandlerFn.Truthy() {
conn.wsHandlerFn.Release()
}
if conn.onMessageFn.Truthy() {
conn.onMessageFn.Release()
}
if conn.onCloseFn.Truthy() {
conn.onCloseFn.Release()
}
p.mu.Lock()
delete(p.activeConnections, conn.id)
delete(p.destinations, conn.id)
delete(p.pendingHandlers, conn.id)
p.mu.Unlock()
})
} }
func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) { func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {

View File

@@ -13,7 +13,7 @@ import (
func CreateJSInterface(client *Client) js.Value { func CreateJSInterface(client *Client) js.Value {
jsInterface := js.Global().Get("Object").Call("create", js.Null()) jsInterface := js.Global().Get("Object").Call("create", js.Null())
jsInterface.Set("write", js.FuncOf(func(this js.Value, args []js.Value) any { writeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 1 { if len(args) < 1 {
return js.ValueOf(false) return js.ValueOf(false)
} }
@@ -32,9 +32,10 @@ func CreateJSInterface(client *Client) js.Value {
_, err := client.Write(bytes) _, err := client.Write(bytes)
return js.ValueOf(err == nil) return js.ValueOf(err == nil)
})) })
jsInterface.Set("write", writeFunc)
jsInterface.Set("resize", js.FuncOf(func(this js.Value, args []js.Value) any { resizeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 2 { if len(args) < 2 {
return js.ValueOf(false) return js.ValueOf(false)
} }
@@ -42,14 +43,26 @@ func CreateJSInterface(client *Client) js.Value {
rows := args[1].Int() rows := args[1].Int()
err := client.Resize(cols, rows) err := client.Resize(cols, rows)
return js.ValueOf(err == nil) return js.ValueOf(err == nil)
})) })
jsInterface.Set("resize", resizeFunc)
jsInterface.Set("close", js.FuncOf(func(this js.Value, args []js.Value) any { closeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
client.Close() client.Close()
return js.Undefined() return js.Undefined()
})) })
jsInterface.Set("close", closeFunc)
go readLoop(client, jsInterface) go func() {
readLoop(client, jsInterface)
// Detach before releasing so late JS calls surface as TypeError instead
// of silent "call to released function".
jsInterface.Set("write", js.Undefined())
jsInterface.Set("resize", js.Undefined())
jsInterface.Set("close", js.Undefined())
writeFunc.Release()
resizeFunc.Release()
closeFunc.Release()
}()
return jsInterface return jsInterface
} }

View File

@@ -332,7 +332,7 @@ func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) {
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress) log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
} }
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg)) s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), s.IDPHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
if servers.relaySrv != nil { if servers.relaySrv != nil {
log.Infof("Relay WebSocket handler added (path: /relay)") log.Infof("Relay WebSocket handler added (path: /relay)")
} }
@@ -521,7 +521,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
} }
// createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic // createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler { func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter)) wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
var relayAcceptFn func(conn listener.Conn) var relayAcceptFn func(conn listener.Conn)
@@ -556,6 +556,10 @@ func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, re
http.Error(w, "Relay service not enabled", http.StatusNotFound) http.Error(w, "Relay service not enabled", http.StatusNotFound)
} }
// Embedded IdP (Dex)
case idpHandler != nil && strings.HasPrefix(r.URL.Path, "/oauth2"):
idpHandler.ServeHTTP(w, r)
// Management HTTP API (default) // Management HTTP API (default)
default: default:
httpHandler.ServeHTTP(w, r) httpHandler.ServeHTTP(w, r)

View File

@@ -53,6 +53,9 @@ type NameServerGroup struct {
ID string `gorm:"primaryKey"` ID string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs // AccountID is a reference to Account that this object belongs
AccountID string `gorm:"index"` AccountID string `gorm:"index"`
// AccountSeqID is a per-account monotonically increasing identifier used as the
// compact wire id when sending NetworkMap components to capable peers.
AccountSeqID uint32 `json:"-" gorm:"index:idx_nameserver_groups_account_seq_id;not null;default:0"`
// Name group name // Name group name
Name string Name string
// Description group description // Description group description

12
go.mod
View File

@@ -3,7 +3,7 @@ module github.com/netbirdio/netbird
go 1.25.5 go 1.25.5
require ( require (
cunicu.li/go-rosenpass v0.4.0 cunicu.li/go-rosenpass v0.5.42
github.com/cenkalti/backoff/v4 v4.3.0 github.com/cenkalti/backoff/v4 v4.3.0
github.com/cloudflare/circl v1.3.3 // indirect github.com/cloudflare/circl v1.3.3 // indirect
github.com/golang/protobuf v1.5.4 github.com/golang/protobuf v1.5.4
@@ -19,8 +19,8 @@ require (
github.com/vishvananda/netlink v1.3.1 github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.50.0 golang.org/x/crypto v0.50.0
golang.org/x/sys v0.43.0 golang.org/x/sys v0.43.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
golang.zx2c4.com/wireguard/windows v0.5.3 golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.80.0 google.golang.org/grpc v1.80.0
google.golang.org/protobuf v1.36.11 google.golang.org/protobuf v1.36.11
@@ -38,7 +38,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3 github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3
github.com/c-robinson/iplib v1.0.3 github.com/c-robinson/iplib v1.0.3
github.com/caddyserver/certmagic v0.21.3 github.com/caddyserver/certmagic v0.21.3
github.com/cilium/ebpf v0.15.0 github.com/cilium/ebpf v0.19.0
github.com/coder/websocket v1.8.14 github.com/coder/websocket v1.8.14
github.com/coreos/go-iptables v0.7.0 github.com/coreos/go-iptables v0.7.0
github.com/coreos/go-oidc/v3 v3.18.0 github.com/coreos/go-oidc/v3 v3.18.0
@@ -60,7 +60,7 @@ require (
github.com/google/go-cmp v0.7.0 github.com/google/go-cmp v0.7.0
github.com/google/gopacket v1.1.19 github.com/google/gopacket v1.1.19
github.com/google/nftables v0.3.0 github.com/google/nftables v0.3.0
github.com/gopacket/gopacket v1.1.1 github.com/gopacket/gopacket v1.4.0
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
@@ -335,7 +335,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f
replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6

26
go.sum
View File

@@ -7,8 +7,8 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdB
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:b8xUw3004wk+3ipBhu0VU4RtUJsegMIiqjxSK4++lzA= codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:b8xUw3004wk+3ipBhu0VU4RtUJsegMIiqjxSK4++lzA=
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw= codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw= cunicu.li/go-rosenpass v0.5.42 h1:fRDsGwCxd7DhDgZI1Pxeo8GtNyq8BESZJ7w2/BGGJtU=
cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M= cunicu.li/go-rosenpass v0.5.42/go.mod h1:YRBeyKOe/gWpSX2kpDUec5p9t0XOLsshTguId5gTGVg=
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw= filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw=
@@ -111,8 +111,8 @@ github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1x
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk= github.com/cilium/ebpf v0.19.0 h1:Ro/rE64RmFBeA9FGjcTc+KmCeY6jXmryu6FfnzPRIao=
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso= github.com/cilium/ebpf v0.19.0/go.mod h1:fLCgMo3l8tZmAdM3B2XqdFzXBpwkcSTroaVqN08OWVY=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
@@ -225,8 +225,8 @@ github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3Bum
github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ= github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ=
github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM= github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM=
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY= github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI= github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6 h1:teYtXy9B7y5lHTp8V9KPxpYRAVA7dozigQcMiBust1s=
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow= github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6/go.mod h1:p4lGIVX+8Wa6ZPNDvqcxq36XpUDLh42FLetFU7odllI=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
@@ -307,8 +307,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg= github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI= github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4= github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw= github.com/gopacket/gopacket v1.4.0 h1:cr1OlFpzksCkZHNO0eLjaSSOrMQnpPXg0j6qHIY3y2U=
github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs= github.com/gopacket/gopacket v1.4.0/go.mod h1:EpvsxINeehp5qj4YMKMLf2/dekdhKn2IIAO/ZOifS7o=
github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE= github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w= github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
@@ -390,6 +390,8 @@ github.com/jonboulle/clockwork v0.5.0 h1:Hyh9A8u51kptdkR+cqRpT1EebBwTn1oK9YfGYbd
github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60= github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/jsimonetti/rtnetlink/v2 v2.0.1 h1:xda7qaHDSVOsADNouv7ukSuicKZO7GgVUCXxpaIEIlM=
github.com/jsimonetti/rtnetlink/v2 v2.0.1/go.mod h1:7MoNYNbb3UaDHtF8udiJo/RH6VsTKP1pqKLUTVCvToE=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M= github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M=
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw= github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
@@ -499,8 +501,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ= github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f h1:ff2D57RBjWtyQ2wVwJOxOgXAXOe/J2lJWtSX0Bz/BRk=
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw= github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk= github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
@@ -900,8 +902,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=

View File

@@ -308,7 +308,7 @@ func (s *Storage) OpenStorage(logger *slog.Logger) (storage.Storage, error) {
if file == "" { if file == "" {
return nil, fmt.Errorf("sqlite3 storage requires 'file' config") return nil, fmt.Errorf("sqlite3 storage requires 'file' config")
} }
return (&sql.SQLite3{File: file}).Open(logger) return newSQLite3(file).Open(logger)
case "postgres": case "postgres":
dsn, _ := s.Config["dsn"].(string) dsn, _ := s.Config["dsn"].(string)
if dsn == "" { if dsn == "" {

View File

@@ -20,7 +20,6 @@ import (
"github.com/dexidp/dex/server" "github.com/dexidp/dex/server"
"github.com/dexidp/dex/server/signer" "github.com/dexidp/dex/server/signer"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/sql"
jose "github.com/go-jose/go-jose/v4" jose "github.com/go-jose/go-jose/v4"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
@@ -77,7 +76,7 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
// Initialize SQLite storage // Initialize SQLite storage
dbPath := filepath.Join(config.DataDir, "oidc.db") dbPath := filepath.Join(config.DataDir, "oidc.db")
sqliteConfig := &sql.SQLite3{File: dbPath} sqliteConfig := newSQLite3(dbPath)
stor, err := sqliteConfig.Open(logger) stor, err := sqliteConfig.Open(logger)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to open storage: %w", err) return nil, fmt.Errorf("failed to open storage: %w", err)

15
idp/dex/sqlite_cgo.go Normal file
View File

@@ -0,0 +1,15 @@
//go:build cgo
package dex
import (
sql "github.com/dexidp/dex/storage/sql"
)
// newSQLite3 builds the dex SQLite3 config. CGO builds use the upstream
// struct that takes a File path. Non-CGO builds get an empty stub whose
// Open() returns the dex "SQLite not available" error — correct behaviour
// for binaries that can't link sqlite3 (e.g. cross-compiled ARM targets).
func newSQLite3(file string) *sql.SQLite3 {
return &sql.SQLite3{File: file}
}

15
idp/dex/sqlite_nocgo.go Normal file
View File

@@ -0,0 +1,15 @@
//go:build !cgo
package dex
import (
sql "github.com/dexidp/dex/storage/sql"
)
// newSQLite3 for non-CGO builds. The dex SQLite3 stub has no fields and its
// Open() returns an error documenting the missing CGO support — correct
// behaviour for cross-compiled artefacts that never actually run the
// embedded IdP. The `file` argument is ignored.
func newSQLite3(_ string) *sql.SQLite3 {
return &sql.SQLite3{}
}

View File

@@ -55,6 +55,12 @@ type Controller struct {
proxyController port_forwarding.Controller proxyController port_forwarding.Controller
integratedPeerValidator integrated_validator.IntegratedValidator integratedPeerValidator integrated_validator.IntegratedValidator
// componentsDisabled, when true, forces the controller to emit legacy
// proto.NetworkMap to every peer regardless of capability. Set once at
// construction and never written after — readers race-free without a
// mutex.
componentsDisabled bool
} }
type bufferUpdate struct { type bufferUpdate struct {
@@ -81,12 +87,27 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
settingsManager: settingsManager, settingsManager: settingsManager,
dnsDomain: dnsDomain, dnsDomain: dnsDomain,
config: config, config: config,
componentsDisabled: parseBoolEnv("NB_NETWORK_MAP_COMPONENTS_DISABLE"),
proxyController: proxyController, proxyController: proxyController,
EphemeralPeersManager: ephemeralPeersManager, EphemeralPeersManager: ephemeralPeersManager,
} }
} }
// PeerNeedsComponents reports whether the gRPC layer should emit the
// component-based wire format for this peer.
func (c *Controller) PeerNeedsComponents(p *nbpeer.Peer) bool {
return p != nil && p.SupportsComponentNetworkMap() && !c.componentsDisabled
}
// parseBoolEnv reads an env var via strconv.ParseBool so callers accept the
// usual "1/t/T/TRUE/true/True" set instead of being strict about a single
// literal.
func parseBoolEnv(key string) bool {
v, _ := strconv.ParseBool(os.Getenv(key))
return v
}
func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) { func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) {
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID) peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
if err != nil { if err != nil {
@@ -112,7 +133,7 @@ func (c *Controller) CountStreams() int {
return c.peersUpdateManager.CountStreams() return c.peersUpdateManager.CountStreams()
} }
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error { func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil { if err != nil {
@@ -175,6 +196,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
continue continue
} }
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountNmapTriggered(string(reason.Resource), string(reason.Operation))
}
wg.Add(1) wg.Add(1)
semaphore <- struct{}{} semaphore <- struct{}{}
go func(p *nbpeer.Peer) { go func(p *nbpeer.Peer) {
@@ -192,18 +217,26 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
c.metrics.CountCalcPostureChecksDuration(time.Since(start)) c.metrics.CountCalcPostureChecksDuration(time.Since(start))
start = time.Now() start = time.Now()
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) result := account.GetPeerNetworkMapResult(ctx, p.ID, c.componentsDisabled, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start)) c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap := proxyNetworkMaps[p.ID]
if ok { if result.NetworkMap != nil && proxyNetworkMap != nil {
remotePeerNetworkMap.Merge(proxyNetworkMap) result.NetworkMap.Merge(proxyNetworkMap)
} }
peerGroups := account.GetPeerGroups(p.ID) peerGroups := account.GetPeerGroups(p.ID)
start = time.Now() start = time.Now()
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) var update *proto.SyncResponse
if result.IsComponents() {
// proxyNetworkMap rides the envelope as a ProxyPatch sidecar;
// the client merges it into Calculate()'s output the same
// way the legacy server did via NetworkMap.Merge.
update = grpc.ToComponentSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, result.Components, proxyNetworkMap, dnsDomain, postureChecks, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
} else {
update = grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, result.NetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
}
c.metrics.CountToSyncResponseDuration(time.Since(start)) c.metrics.CountToSyncResponseDuration(time.Since(start))
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{ c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
@@ -242,14 +275,14 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
go func() { go func() {
defer b.mu.Unlock() defer b.mu.Unlock()
_ = c.sendUpdateAccountPeers(ctx, accountID) _ = c.sendUpdateAccountPeers(ctx, accountID, reason)
if !b.update.Load() { if !b.update.Load() {
return return
} }
b.update.Store(false) b.update.Store(false)
if b.next == nil { if b.next == nil {
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() { b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
_ = c.sendUpdateAccountPeers(ctx, accountID) _ = c.sendUpdateAccountPeers(ctx, accountID, reason)
}) })
return return
} }
@@ -265,7 +298,7 @@ func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, r
if c.accountManagerMetrics != nil { if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation)) c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
} }
return c.sendUpdateAccountPeers(ctx, accountID) return c.sendUpdateAccountPeers(ctx, accountID, reason)
} }
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error { func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
@@ -314,11 +347,11 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return err return err
} }
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) result := account.GetPeerNetworkMapResult(ctx, peerId, c.componentsDisabled, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap := proxyNetworkMaps[peer.ID]
if ok { if result.NetworkMap != nil && proxyNetworkMap != nil {
remotePeerNetworkMap.Merge(proxyNetworkMap) result.NetworkMap.Merge(proxyNetworkMap)
} }
extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID) extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID)
@@ -329,7 +362,12 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
peerGroups := account.GetPeerGroups(peerId) peerGroups := account.GetPeerGroups(peerId)
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion) dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) var update *proto.SyncResponse
if result.IsComponents() {
update = grpc.ToComponentSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, result.Components, proxyNetworkMap, dnsDomain, postureChecks, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
} else {
update = grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, result.NetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
}
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{ c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
Update: update, Update: update,
MessageType: network_map.MessageTypeNetworkMap, MessageType: network_map.MessageTypeNetworkMap,
@@ -359,14 +397,14 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
go func() { go func() {
defer b.mu.Unlock() defer b.mu.Unlock()
_ = c.sendUpdateAccountPeers(ctx, accountID) _ = c.sendUpdateAccountPeers(ctx, accountID, reason)
if !b.update.Load() { if !b.update.Load() {
return return
} }
b.update.Store(false) b.update.Store(false)
if b.next == nil { if b.next == nil {
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() { b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
_ = c.sendUpdateAccountPeers(ctx, accountID) _ = c.sendUpdateAccountPeers(ctx, accountID, reason)
}) })
return return
} }
@@ -376,6 +414,66 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
return nil return nil
} }
// GetValidatedPeerWithComponents is the components-format counterpart of
// GetValidatedPeerWithMap. It returns raw NetworkMapComponents for capable
// peers along with the proxy NetworkMap fragment (BYOP / port-forwarding
// data the legacy server folds in via NetworkMap.Merge). The gRPC layer
// encodes both into the wire envelope. Callers must gate on capability
// themselves before dispatching here — this method does NOT branch on it.
func (c *Controller) GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return nil, nil, nil, nil, 0, err
}
return peer, &types.NetworkMapComponents{Network: network.Copy()}, nil, nil, 0, nil
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, nil, 0, err
}
account.InjectProxyPolicies(ctx)
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, nil, nil, nil, 0, err
}
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
if err != nil {
return nil, nil, nil, nil, 0, err
}
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
return nil, nil, nil, nil, 0, err
}
// Fetch the proxy network map fragment for this peer alongside the
// components — same single-account-load path the streaming controller
// uses, so initial-sync delivers BYOP/forwarding patches synchronously
// instead of waiting for the next streaming push.
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, nil, nil, nil, 0, err
}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
components := account.GetPeerNetworkMapComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, groupIDToUserIDs)
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
return peer, components, proxyNetworkMaps[peer.ID], postureChecks, dnsFwdPort, nil
}
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval { if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID) network, err := c.repo.GetAccountNetwork(ctx, accountID)

View File

@@ -22,6 +22,10 @@ type Controller interface {
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error)
// PeerNeedsComponents combines the peer's advertised capability with the
// kill-switch flag — the only public predicate gRPC layers should ask.
PeerNeedsComponents(p *nbpeer.Peer) bool
GetDNSDomain(settings *types.Settings) string GetDNSDomain(settings *types.Settings) string
StartWarmup(context.Context) StartWarmup(context.Context)
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)

View File

@@ -130,6 +130,39 @@ func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApp
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
} }
// GetValidatedPeerWithComponents mocks base method.
func (m *MockController) GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetValidatedPeerWithComponents", ctx, isRequiresApproval, accountID, p)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(*types.NetworkMapComponents)
ret2, _ := ret[2].(*types.NetworkMap)
ret3, _ := ret[3].([]*posture.Checks)
ret4, _ := ret[4].(int64)
ret5, _ := ret[5].(error)
return ret0, ret1, ret2, ret3, ret4, ret5
}
// GetValidatedPeerWithComponents indicates an expected call of GetValidatedPeerWithComponents.
func (mr *MockControllerMockRecorder) GetValidatedPeerWithComponents(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithComponents", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithComponents), ctx, isRequiresApproval, accountID, p)
}
// PeerNeedsComponents mocks base method.
func (m *MockController) PeerNeedsComponents(p *peer.Peer) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PeerNeedsComponents", p)
ret0, _ := ret[0].(bool)
return ret0
}
// PeerNeedsComponents indicates an expected call of PeerNeedsComponents.
func (mr *MockControllerMockRecorder) PeerNeedsComponents(p any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeerNeedsComponents", reflect.TypeOf((*MockController)(nil).PeerNeedsComponents), p)
}
// OnPeerConnected mocks base method. // OnPeerConnected mocks base method.
func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) { func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@@ -51,7 +51,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
found = true found = true
select { select {
case channel <- update: case channel <- update:
log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID) log.WithContext(ctx).Tracef("update was sent to channel for peer %s", peerID)
default: default:
dropped = true dropped = true
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel)) log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))

View File

@@ -5,6 +5,7 @@ package peers
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"time" "time"
"github.com/rs/xid" "github.com/rs/xid"
@@ -35,6 +36,14 @@ type Manager interface {
SetAccountManager(accountManager account.Manager) SetAccountManager(accountManager account.Manager)
GetPeerID(ctx context.Context, peerKey string) (string, error) GetPeerID(ctx context.Context, peerKey string) (string, error)
CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error
// GetPeerByTunnelIP looks up a peer in accountID by its WireGuard tunnel IP.
// Returns nil with an error when no match exists. No permission check;
// callers (the proxy's ValidateTunnelPeer RPC) are trusted server components.
GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error)
// GetPeerWithGroups returns the peer and the list of *types.Group it belongs
// to. Used by the proxy's auth path to authorise a request by the calling
// peer's group memberships.
GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error)
} }
type managerImpl struct { type managerImpl struct {
@@ -99,6 +108,26 @@ func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string,
return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs) return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
} }
// GetPeerByTunnelIP delegates to the store's indexed lookup.
func (m *managerImpl) GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) {
return m.store.GetPeerByIP(ctx, store.LockingStrengthNone, accountID, ip)
}
// GetPeerWithGroups returns the peer plus its group memberships. Any store
// error returns (nil, nil, err) so callers never receive a valid peer
// alongside a non-nil error.
func (m *managerImpl) GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) {
p, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return nil, nil, err
}
groups, err := m.store.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return nil, nil, err
}
return p, groups, nil
}
func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {

View File

@@ -6,6 +6,7 @@ package peers
import ( import (
context "context" context "context"
net "net"
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
@@ -13,6 +14,7 @@ import (
account "github.com/netbirdio/netbird/management/server/account" account "github.com/netbirdio/netbird/management/server/account"
integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
peer "github.com/netbirdio/netbird/management/server/peer" peer "github.com/netbirdio/netbird/management/server/peer"
types "github.com/netbirdio/netbird/management/server/types"
) )
// MockManager is a mock of Manager interface. // MockManager is a mock of Manager interface.
@@ -38,6 +40,20 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder {
return m.recorder return m.recorder
} }
// CreateProxyPeer mocks base method.
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID, peerKey, cluster string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
ret0, _ := ret[0].(error)
return ret0
}
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
}
// DeletePeers mocks base method. // DeletePeers mocks base method.
func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -97,6 +113,21 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID)
} }
// GetPeerByTunnelIP mocks base method.
func (m *MockManager) GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeerByTunnelIP", ctx, accountID, ip)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPeerByTunnelIP indicates an expected call of GetPeerByTunnelIP.
func (mr *MockManagerMockRecorder) GetPeerByTunnelIP(ctx, accountID, ip interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerByTunnelIP", reflect.TypeOf((*MockManager)(nil).GetPeerByTunnelIP), ctx, accountID, ip)
}
// GetPeerID mocks base method. // GetPeerID mocks base method.
func (m *MockManager) GetPeerID(ctx context.Context, peerKey string) (string, error) { func (m *MockManager) GetPeerID(ctx context.Context, peerKey string) (string, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -112,6 +143,22 @@ func (mr *MockManagerMockRecorder) GetPeerID(ctx, peerKey interface{}) *gomock.C
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerID", reflect.TypeOf((*MockManager)(nil).GetPeerID), ctx, peerKey) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerID", reflect.TypeOf((*MockManager)(nil).GetPeerID), ctx, peerKey)
} }
// GetPeerWithGroups mocks base method.
func (m *MockManager) GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeerWithGroups", ctx, accountID, peerID)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].([]*types.Group)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// GetPeerWithGroups indicates an expected call of GetPeerWithGroups.
func (mr *MockManagerMockRecorder) GetPeerWithGroups(ctx, accountID, peerID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerWithGroups", reflect.TypeOf((*MockManager)(nil).GetPeerWithGroups), ctx, accountID, peerID)
}
// GetPeersByGroupIDs mocks base method. // GetPeersByGroupIDs mocks base method.
func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) { func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -162,17 +209,3 @@ func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
} }
// CreateProxyPeer mocks base method.
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
ret0, _ := ret[0].(error)
return ret0
}
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
}

View File

@@ -23,6 +23,8 @@ type Domain struct {
// SupportsCrowdSec is populated at query time from proxy cluster capabilities. // SupportsCrowdSec is populated at query time from proxy cluster capabilities.
// Not persisted. // Not persisted.
SupportsCrowdSec *bool `gorm:"-"` SupportsCrowdSec *bool `gorm:"-"`
// SupportsPrivate is populated at query time from proxy cluster capabilities. Not persisted.
SupportsPrivate *bool `gorm:"-"`
} }
// EventMeta returns activity event metadata for a domain // EventMeta returns activity event metadata for a domain

View File

@@ -49,6 +49,7 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
SupportsCustomPorts: d.SupportsCustomPorts, SupportsCustomPorts: d.SupportsCustomPorts,
RequireSubdomain: d.RequireSubdomain, RequireSubdomain: d.RequireSubdomain,
SupportsCrowdsec: d.SupportsCrowdSec, SupportsCrowdsec: d.SupportsCrowdSec,
SupportsPrivate: d.SupportsPrivate,
} }
if d.TargetCluster != "" { if d.TargetCluster != "" {
resp.TargetCluster = &d.TargetCluster resp.TargetCluster = &d.TargetCluster

View File

@@ -35,6 +35,7 @@ type proxyManager interface {
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
} }
type Manager struct { type Manager struct {
@@ -93,6 +94,7 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
d.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, cluster) d.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, cluster)
d.RequireSubdomain = m.proxyManager.ClusterRequireSubdomain(ctx, cluster) d.RequireSubdomain = m.proxyManager.ClusterRequireSubdomain(ctx, cluster)
d.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, cluster) d.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, cluster)
d.SupportsPrivate = m.proxyManager.ClusterSupportsPrivate(ctx, cluster)
ret = append(ret, d) ret = append(ret, d)
} }
@@ -109,6 +111,7 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
if d.TargetCluster != "" { if d.TargetCluster != "" {
cd.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, d.TargetCluster) cd.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, d.TargetCluster)
cd.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, d.TargetCluster) cd.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, d.TargetCluster)
cd.SupportsPrivate = m.proxyManager.ClusterSupportsPrivate(ctx, d.TargetCluster)
} }
// Custom domains never require a subdomain by default since // Custom domains never require a subdomain by default since
// the account owns them and should be able to use the bare domain. // the account owns them and should be able to use the bare domain.

View File

@@ -10,7 +10,7 @@ import (
) )
type mockProxyManager struct { type mockProxyManager struct {
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error) getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error) getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
} }
@@ -40,6 +40,10 @@ func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string)
return nil return nil
} }
func (m *mockProxyManager) ClusterSupportsPrivate(_ context.Context, _ string) *bool {
return nil
}
func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) { func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) {
pm := &mockProxyManager{ pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) { getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
@@ -151,4 +155,3 @@ func TestGetClusterAllowList_PublicEmpty_BYOPOnly(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"byop.example.com"}, result) assert.Equal(t, []string{"byop.example.com"}, result)
} }

View File

@@ -19,6 +19,7 @@ type Manager interface {
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error) GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
CountAccountProxies(ctx context.Context, accountID string) (int64, error) CountAccountProxies(ctx context.Context, accountID string) (int64, error)

View File

@@ -21,6 +21,7 @@ type store interface {
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
@@ -137,6 +138,11 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string
return m.store.GetClusterSupportsCrowdSec(ctx, clusterAddr) return m.store.GetClusterSupportsCrowdSec(ctx, clusterAddr)
} }
// ClusterSupportsPrivate reports whether any active proxy claims the private capability (nil = unreported).
func (m Manager) ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
return m.store.GetClusterSupportsPrivate(ctx, clusterAddr)
}
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration // CleanupStale removes proxies that haven't sent heartbeat in the specified duration
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error { func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil { if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
@@ -178,4 +184,3 @@ func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, acco
} }
return nil return nil
} }

View File

@@ -15,16 +15,16 @@ import (
) )
type mockStore struct { type mockStore struct {
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error) getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error) getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error) getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error) countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error) isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
} }
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error { func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
@@ -99,6 +99,9 @@ func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *boo
func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool { func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
return nil return nil
} }
func (m *mockStore) GetClusterSupportsPrivate(_ context.Context, _ string) *bool {
return nil
}
func newTestManager(s store) *Manager { func newTestManager(s store) *Manager {
meter := noop.NewMeterProvider().Meter("test") meter := noop.NewMeterProvider().Meter("test")

View File

@@ -92,6 +92,20 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCrowdSec", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCrowdSec), ctx, clusterAddr) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCrowdSec", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCrowdSec), ctx, clusterAddr)
} }
// ClusterSupportsPrivate mocks base method.
func (m *MockManager) ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ClusterSupportsPrivate", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// ClusterSupportsPrivate indicates an expected call of ClusterSupportsPrivate.
func (mr *MockManagerMockRecorder) ClusterSupportsPrivate(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsPrivate", reflect.TypeOf((*MockManager)(nil).ClusterSupportsPrivate), ctx, clusterAddr)
}
// Connect mocks base method. // Connect mocks base method.
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) { func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@@ -20,6 +20,9 @@ type Capabilities struct {
RequireSubdomain *bool RequireSubdomain *bool
// SupportsCrowdsec indicates whether this proxy has CrowdSec configured. // SupportsCrowdsec indicates whether this proxy has CrowdSec configured.
SupportsCrowdsec *bool SupportsCrowdsec *bool
// Private indicates whether this proxy supports inbound access via Wireguard
// tunnel and netbird-only authentication policies
Private *bool
} }
// Proxy represents a reverse proxy instance // Proxy represents a reverse proxy instance
@@ -67,10 +70,9 @@ type Cluster struct {
Type ClusterType Type ClusterType
Online bool Online bool
ConnectedProxies int ConnectedProxies int
// Capability flags. *bool because nil means "no proxy reported a // *bool: nil = no proxy reported the capability; the dashboard renders that as unknown.
// capability for this cluster" — the dashboard renders these as
// unknown rather than false.
SupportsCustomPorts *bool SupportsCustomPorts *bool
RequireSubdomain *bool RequireSubdomain *bool
SupportsCrowdSec *bool SupportsCrowdSec *bool
Private *bool
} }

View File

@@ -204,6 +204,7 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
SupportsCustomPorts: c.SupportsCustomPorts, SupportsCustomPorts: c.SupportsCustomPorts,
RequireSubdomain: c.RequireSubdomain, RequireSubdomain: c.RequireSubdomain,
SupportsCrowdsec: c.SupportsCrowdSec, SupportsCrowdsec: c.SupportsCrowdSec,
Private: c.Private,
}) })
} }

View File

@@ -82,6 +82,7 @@ type CapabilityProvider interface {
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
} }
type Manager struct { type Manager struct {
@@ -136,6 +137,7 @@ func (m *Manager) GetClusters(ctx context.Context, accountID, userID string) ([]
clusters[i].SupportsCustomPorts = m.capabilities.ClusterSupportsCustomPorts(ctx, clusters[i].Address) clusters[i].SupportsCustomPorts = m.capabilities.ClusterSupportsCustomPorts(ctx, clusters[i].Address)
clusters[i].RequireSubdomain = m.capabilities.ClusterRequireSubdomain(ctx, clusters[i].Address) clusters[i].RequireSubdomain = m.capabilities.ClusterRequireSubdomain(ctx, clusters[i].Address)
clusters[i].SupportsCrowdSec = m.capabilities.ClusterSupportsCrowdSec(ctx, clusters[i].Address) clusters[i].SupportsCrowdSec = m.capabilities.ClusterSupportsCrowdSec(ctx, clusters[i].Address)
clusters[i].Private = m.capabilities.ClusterSupportsPrivate(ctx, clusters[i].Address)
} }
return clusters, nil return clusters, nil
@@ -208,6 +210,9 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
target.Host = resource.Domain target.Host = resource.Domain
case service.TargetTypeSubnet: case service.TargetTypeSubnet:
// For subnets we do not do any lookups on the resource // For subnets we do not do any lookups on the resource
case service.TargetTypeCluster:
// Cluster targets carry the upstream address on target_id; the
// proxy resolves the destination at request time.
default: default:
return fmt.Errorf("unknown target type: %s", target.TargetType) return fmt.Errorf("unknown target type: %s", target.TargetType)
} }
@@ -779,6 +784,10 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
if err := validateResourceTarget(ctx, transaction, accountID, target); err != nil { if err := validateResourceTarget(ctx, transaction, accountID, target); err != nil {
return err return err
} }
case service.TargetTypeCluster:
if err := validateClusterTarget(target); err != nil {
return err
}
default: default:
return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId) return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId)
} }
@@ -786,6 +795,13 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
return nil return nil
} }
func validateClusterTarget(target *service.Target) error {
if !target.Options.DirectUpstream {
return status.Errorf(status.InvalidArgument, "cluster target %s has direct upstream disabled", target.Host)
}
return nil
}
func validatePeerTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error { func validatePeerTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error {
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil { if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
@@ -962,12 +978,14 @@ func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID str
return fmt.Errorf("failed to get services: %w", err) return fmt.Errorf("failed to get services: %w", err)
} }
oidcCfg := m.proxyController.GetOIDCValidationConfig()
for _, s := range services { for _, s := range services {
err = m.replaceHostByLookup(ctx, accountID, s) err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil { if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err) return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
} }
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
} }
return nil return nil

View File

@@ -1344,3 +1344,66 @@ func TestValidateSubdomainRequirement(t *testing.T) {
}) })
} }
} }
func TestValidateTargetReferences_ClusterTargetSkipsLookup(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
accountID := "test-account"
// No peer or resource lookups must be issued for cluster targets.
targets := []*rpservice.Target{
{
TargetId: "eu.proxy.netbird.io",
TargetType: rpservice.TargetTypeCluster,
Options: rpservice.TargetOptions{DirectUpstream: true},
},
}
require.NoError(t, validateTargetReferences(ctx, mockStore, accountID, targets), "cluster target must validate without store lookups")
}
// TestValidateTargetReferences_ClusterTargetRequiresDirectUpstream pins the
// store-side check that cluster targets must opt into the host-stack dial
// path. Without DirectUpstream the proxy would route this target through
// the embedded NetBird client and fail on every request.
func TestValidateTargetReferences_ClusterTargetRequiresDirectUpstream(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
accountID := "test-account"
targets := []*rpservice.Target{
{
TargetId: "eu.proxy.netbird.io",
TargetType: rpservice.TargetTypeCluster,
Host: "backend.lan",
},
}
err := validateTargetReferences(ctx, mockStore, accountID, targets)
require.Error(t, err, "cluster target without direct_upstream must be rejected")
assert.ErrorContains(t, err, "direct upstream disabled")
}
func TestReplaceHostByLookup_SkipsClusterTarget(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
accountID := "test-account"
mgr := &Manager{store: mockStore}
svc := &rpservice.Service{
ID: "svc-1",
AccountID: accountID,
Targets: []*rpservice.Target{
{
TargetId: "eu.proxy.netbird.io",
TargetType: rpservice.TargetTypeCluster,
Host: "127.0.0.1",
},
},
}
require.NoError(t, mgr.replaceHostByLookup(ctx, accountID, svc), "cluster target must not trigger peer/resource lookup")
assert.Equal(t, "127.0.0.1", svc.Targets[0].Host, "operator-supplied host must be preserved for cluster target")
}

View File

@@ -45,10 +45,11 @@ const (
StatusCertificateFailed Status = "certificate_failed" StatusCertificateFailed Status = "certificate_failed"
StatusError Status = "error" StatusError Status = "error"
TargetTypePeer TargetType = "peer" TargetTypePeer TargetType = "peer"
TargetTypeHost TargetType = "host" TargetTypeHost TargetType = "host"
TargetTypeDomain TargetType = "domain" TargetTypeDomain TargetType = "domain"
TargetTypeSubnet TargetType = "subnet" TargetTypeSubnet TargetType = "subnet"
TargetTypeCluster TargetType = "cluster"
SourcePermanent = "permanent" SourcePermanent = "permanent"
SourceEphemeral = "ephemeral" SourceEphemeral = "ephemeral"
@@ -60,6 +61,11 @@ type TargetOptions struct {
SessionIdleTimeout time.Duration `json:"session_idle_timeout,omitempty"` SessionIdleTimeout time.Duration `json:"session_idle_timeout,omitempty"`
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"` PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"` CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
// DirectUpstream bypasses the proxy's embedded NetBird client and dials
// the target via the proxy host's network stack. Useful for upstreams
// reachable without WireGuard (public APIs, LAN services, localhost
// sidecars). Default false.
DirectUpstream bool `json:"direct_upstream,omitempty"`
} }
type Target struct { type Target struct {
@@ -67,7 +73,7 @@ type Target struct {
AccountID string `gorm:"index:idx_target_account;not null" json:"-"` AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"` ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
Path *string `json:"path,omitempty"` Path *string `json:"path,omitempty"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored Host string `json:"host"`
Port uint16 `gorm:"index:idx_target_port" json:"port"` Port uint16 `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"` Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"` TargetId string `gorm:"index:idx_target_id" json:"target_id"`
@@ -200,6 +206,10 @@ type Service struct {
Mode string `gorm:"default:'http'"` Mode string `gorm:"default:'http'"`
ListenPort uint16 ListenPort uint16
PortAutoAssigned bool PortAutoAssigned bool
// Private marks the service as NetBird-only: auth via ValidateTunnelPeer against AccessGroups instead of SSO. HTTP-only.
Private bool
// AccessGroups is the group ID allowlist for inbound peers on private services. Mutually exclusive with bearer SSO.
AccessGroups []string `json:"access_groups,omitempty" gorm:"serializer:json"`
} }
// InitNewRecord generates a new unique ID and resets metadata for a newly created // InitNewRecord generates a new unique ID and resets metadata for a newly created
@@ -299,6 +309,12 @@ func (s *Service) ToAPIResponse() *api.Service {
Mode: &mode, Mode: &mode,
ListenPort: &listenPort, ListenPort: &listenPort,
PortAutoAssigned: &s.PortAutoAssigned, PortAutoAssigned: &s.PortAutoAssigned,
Private: &s.Private,
}
if len(s.AccessGroups) > 0 {
groups := append([]string(nil), s.AccessGroups...)
resp.AccessGroups = &groups
} }
if s.ProxyCluster != "" { if s.ProxyCluster != "" {
@@ -308,6 +324,7 @@ func (s *Service) ToAPIResponse() *api.Service {
return resp return resp
} }
// ToProtoMapping converts the service into the wire format the proxy consumes.
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping { func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
pathMappings := s.buildPathMappings() pathMappings := s.buildPathMappings()
@@ -349,6 +366,7 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
RewriteRedirects: s.RewriteRedirects, RewriteRedirects: s.RewriteRedirects,
Mode: s.Mode, Mode: s.Mode,
ListenPort: int32(s.ListenPort), //nolint:gosec ListenPort: int32(s.ListenPort), //nolint:gosec
Private: s.Private,
} }
if r := restrictionsToProto(s.Restrictions); r != nil { if r := restrictionsToProto(s.Restrictions); r != nil {
@@ -455,7 +473,8 @@ func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode {
} }
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions { func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 { if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 &&
opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 && !opts.DirectUpstream {
return nil return nil
} }
apiOpts := &api.ServiceTargetOptions{} apiOpts := &api.ServiceTargetOptions{}
@@ -477,17 +496,22 @@ func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
if len(opts.CustomHeaders) > 0 { if len(opts.CustomHeaders) > 0 {
apiOpts.CustomHeaders = &opts.CustomHeaders apiOpts.CustomHeaders = &opts.CustomHeaders
} }
if opts.DirectUpstream {
apiOpts.DirectUpstream = &opts.DirectUpstream
}
return apiOpts return apiOpts
} }
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions { func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 && len(opts.CustomHeaders) == 0 { if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 &&
len(opts.CustomHeaders) == 0 && !opts.DirectUpstream {
return nil return nil
} }
popts := &proto.PathTargetOptions{ popts := &proto.PathTargetOptions{
SkipTlsVerify: opts.SkipTLSVerify, SkipTlsVerify: opts.SkipTLSVerify,
PathRewrite: pathRewriteToProto(opts.PathRewrite), PathRewrite: pathRewriteToProto(opts.PathRewrite),
CustomHeaders: opts.CustomHeaders, CustomHeaders: opts.CustomHeaders,
DirectUpstream: opts.DirectUpstream,
} }
if opts.RequestTimeout != 0 { if opts.RequestTimeout != 0 {
popts.RequestTimeout = durationpb.New(opts.RequestTimeout) popts.RequestTimeout = durationpb.New(opts.RequestTimeout)
@@ -537,6 +561,9 @@ func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions,
if o.CustomHeaders != nil { if o.CustomHeaders != nil {
opts.CustomHeaders = *o.CustomHeaders opts.CustomHeaders = *o.CustomHeaders
} }
if o.DirectUpstream != nil {
opts.DirectUpstream = *o.DirectUpstream
}
return opts, nil return opts, nil
} }
@@ -551,6 +578,14 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro
if req.ListenPort != nil { if req.ListenPort != nil {
s.ListenPort = uint16(*req.ListenPort) //nolint:gosec s.ListenPort = uint16(*req.ListenPort) //nolint:gosec
} }
if req.Private != nil {
s.Private = *req.Private
}
if req.AccessGroups != nil {
s.AccessGroups = append([]string(nil), *req.AccessGroups...)
} else {
s.AccessGroups = nil
}
targets, err := targetsFromAPI(accountID, req.Targets) targets, err := targetsFromAPI(accountID, req.Targets)
if err != nil { if err != nil {
@@ -740,6 +775,9 @@ func (s *Service) Validate() error {
if err := validateAccessRestrictions(&s.Restrictions); err != nil { if err := validateAccessRestrictions(&s.Restrictions); err != nil {
return err return err
} }
if err := s.validatePrivateRequirements(); err != nil {
return err
}
switch s.Mode { switch s.Mode {
case ModeHTTP: case ModeHTTP:
@@ -753,6 +791,23 @@ func (s *Service) Validate() error {
} }
} }
// validatePrivateRequirements enforces the private-service contract: HTTP mode, ≥1 access group, no bearer auth.
func (s *Service) validatePrivateRequirements() error {
if !s.Private {
return nil
}
if s.Mode != "" && s.Mode != ModeHTTP {
return fmt.Errorf("private services only support HTTP mode, got %q", s.Mode)
}
if len(s.AccessGroups) == 0 {
return errors.New("private services require at least one access group")
}
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
return errors.New("private services cannot enable bearer auth (SSO): NetBird-only access and SSO are mutually exclusive")
}
return nil
}
func (s *Service) validateHTTPMode() error { func (s *Service) validateHTTPMode() error {
if s.Domain == "" { if s.Domain == "" {
return errors.New("service domain is required") return errors.New("service domain is required")
@@ -799,11 +854,21 @@ func (s *Service) validateHTTPTargets() error {
for i, target := range s.Targets { for i, target := range s.Targets {
switch target.TargetType { switch target.TargetType {
case TargetTypePeer, TargetTypeHost, TargetTypeDomain: case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
// host field will be ignored // Host is normally overwritten by replaceHostByLookup with the
// resolved peer IP / resource address; operator-supplied values
// are honored only when DirectUpstream is set. Validate the
// override here so misconfigured hosts fail fast at API time.
if err := validateDirectUpstreamHost(i, target); err != nil {
return err
}
case TargetTypeSubnet: case TargetTypeSubnet:
if target.Host == "" { if target.Host == "" {
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType) return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
} }
case TargetTypeCluster:
if err := validateClusterTarget(i, target); err != nil {
return err
}
default: default:
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType) return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
} }
@@ -821,25 +886,67 @@ func (s *Service) validateHTTPTargets() error {
return nil return nil
} }
// validateClusterTarget cluster targets should not have empty hosts and should have direct upstream enabled.
func validateClusterTarget(idx int, target *Target) error {
host := strings.TrimSpace(target.Host)
if host == "" {
return fmt.Errorf("target %d: has empty host", idx)
}
if !target.Options.DirectUpstream {
return fmt.Errorf("target %d: %s has direct upstream disabled", idx, target.Host)
}
return validateDirectUpstreamHost(idx, target)
}
// validateDirectUpstreamHost validates the operator-supplied Host on a
// peer/host/domain target when DirectUpstream is set. Empty Host is
// allowed — the lookup fills in the default peer IP / resource address.
// Without DirectUpstream the Host value is silently overwritten by
// replaceHostByLookup, so we don't validate it (preserves the historical
// behaviour where APIs accepted any value and dropped it). Non-empty
// Host with DirectUpstream must look like a hostname or IP and must
// not carry a port (port lives on Target.Port).
func validateDirectUpstreamHost(idx int, target *Target) error {
if !target.Options.DirectUpstream {
return nil
}
host := strings.TrimSpace(target.Host)
if host == "" {
return nil
}
if strings.ContainsAny(host, " \t/") {
return fmt.Errorf("target %d: host %q contains invalid characters", idx, host)
}
if _, _, err := net.SplitHostPort(host); err == nil {
return fmt.Errorf("target %d: host %q must not include a port (set target.port instead)", idx, host)
}
return nil
}
func (s *Service) validateL4Target(target *Target) error { func (s *Service) validateL4Target(target *Target) error {
// L4 services have a single target; per-target disable is meaningless // L4 services have a single target; per-target disable is meaningless
// (use the service-level Enabled flag instead). Force it on so that // (use the service-level Enabled flag instead). Force it on so that
// buildPathMappings always includes the target in the proto. // buildPathMappings always includes the target in the proto.
target.Enabled = true target.Enabled = true
if target.Port == 0 {
return errors.New("target port is required for L4 services")
}
if target.TargetId == "" { if target.TargetId == "" {
return errors.New("target_id is required for L4 services") return errors.New("target_id is required for L4 services")
} }
if target.TargetType != TargetTypeCluster && target.Port == 0 {
return errors.New("target port is required for L4 services")
}
switch target.TargetType { switch target.TargetType {
case TargetTypePeer, TargetTypeHost, TargetTypeDomain: case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
// OK if err := validateDirectUpstreamHost(0, target); err != nil {
return err
}
case TargetTypeSubnet: case TargetTypeSubnet:
if target.Host == "" { if target.Host == "" {
return errors.New("target host is required for subnet targets") return errors.New("target host is required for subnet targets")
} }
case TargetTypeCluster:
// target_id carries the cluster address; the proxy resolves
// the upstream at request time.
default: default:
return fmt.Errorf("invalid target_type %q for L4 service", target.TargetType) return fmt.Errorf("invalid target_type %q for L4 service", target.TargetType)
} }
@@ -1174,6 +1281,11 @@ func (s *Service) Copy() *Service {
} }
} }
var accessGroups []string
if len(s.AccessGroups) > 0 {
accessGroups = append([]string(nil), s.AccessGroups...)
}
return &Service{ return &Service{
ID: s.ID, ID: s.ID,
AccountID: s.AccountID, AccountID: s.AccountID,
@@ -1195,6 +1307,8 @@ func (s *Service) Copy() *Service {
Mode: s.Mode, Mode: s.Mode,
ListenPort: s.ListenPort, ListenPort: s.ListenPort,
PortAutoAssigned: s.PortAutoAssigned, PortAutoAssigned: s.PortAutoAssigned,
Private: s.Private,
AccessGroups: accessGroups,
} }
} }

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/shared/hash/argon2id" "github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/proto"
) )
@@ -1116,3 +1117,191 @@ func TestValidate_HeaderAuths(t *testing.T) {
assert.Contains(t, err.Error(), "exceeds maximum length") assert.Contains(t, err.Error(), "exceeds maximum length")
}) })
} }
func TestValidate_HTTPClusterTarget(t *testing.T) {
rp := validProxy()
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
require.NoError(t, rp.Validate(), "HTTP cluster target with target_id, host, and direct_upstream must validate")
}
func TestValidate_HTTPClusterTarget_RequiresTargetId(t *testing.T) {
rp := validProxy()
rp.Targets = []*Target{{
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "empty target_id", "cluster target must reject empty target_id")
}
// TestValidate_HTTPClusterTarget_RequiresHost pins the new cluster-target
// rule that operator-supplied Host is mandatory: cluster targets dial the
// upstream via the host network stack (direct_upstream is implied), so an
// empty Host leaves the proxy with nothing to dial.
func TestValidate_HTTPClusterTarget_RequiresHost(t *testing.T) {
rp := validProxy()
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "empty host", "cluster target must reject empty host")
}
// TestValidate_HTTPClusterTarget_RequiresDirectUpstream pins the second
// half of the cluster-target rule: DirectUpstream must be true so the
// stdlib transport branch in MultiTransport is taken. Without it the
// embedded NetBird client would try to dial the cluster address through
// the WG tunnel, which is the wrong network for a cluster upstream.
func TestValidate_HTTPClusterTarget_RequiresDirectUpstream(t *testing.T) {
rp := validProxy()
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "direct upstream disabled", "cluster target must reject direct_upstream=false")
}
func TestValidate_L4ClusterTarget(t *testing.T) {
rp := validProxy()
rp.Mode = ModeTCP
rp.ListenPort = 9000
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "tcp",
Enabled: true,
}}
require.NoError(t, rp.Validate(), "L4 cluster target must validate without an explicit port")
}
func TestService_Copy_RoundtripsPrivate(t *testing.T) {
svc := validProxy()
svc.Private = true
svc.AccessGroups = []string{"grp-admins", "grp-ops"}
cp := svc.Copy()
require.NotNil(t, cp)
assert.True(t, cp.Private)
assert.Equal(t, []string{"grp-admins", "grp-ops"}, cp.AccessGroups)
cp.Private = false
assert.True(t, svc.Private)
cp.AccessGroups[0] = "grp-other"
assert.Equal(t, []string{"grp-admins", "grp-ops"}, svc.AccessGroups)
}
func TestService_APIRoundtrip_Private(t *testing.T) {
enabled := true
private := true
accessGroups := []string{"grp-admins"}
targets := []api.ServiceTarget{{
TargetId: "eu.proxy.netbird.io",
TargetType: api.ServiceTargetTargetType("cluster"),
Protocol: "http",
Port: 80,
Enabled: true,
}}
req := &api.ServiceRequest{
Name: "svc-private",
Domain: "myapp.eu.proxy.netbird.io",
Enabled: enabled,
Private: &private,
AccessGroups: &accessGroups,
Targets: &targets,
}
svc := &Service{}
require.NoError(t, svc.FromAPIRequest(req, "acc-1"))
assert.True(t, svc.Private)
assert.Equal(t, []string{"grp-admins"}, svc.AccessGroups)
resp := svc.ToAPIResponse()
require.NotNil(t, resp.Private)
assert.True(t, *resp.Private)
require.NotNil(t, resp.AccessGroups)
assert.Equal(t, []string{"grp-admins"}, *resp.AccessGroups)
}
func TestValidate_Private_RequiresAccessGroups(t *testing.T) {
rp := validProxy()
rp.Private = true
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "access group")
}
func TestValidate_Private_RejectsBearerAuth(t *testing.T) {
rp := validProxy()
rp.Private = true
rp.AccessGroups = []string{"grp-admins"}
rp.Auth.BearerAuth = &BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"grp-sso"},
}
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "mutually exclusive")
}
func TestValidate_Private_AcceptsNonClusterTargets(t *testing.T) {
rp := validProxy()
rp.Private = true
rp.AccessGroups = []string{"grp-admins"}
require.NoError(t, rp.Validate())
}
func TestValidate_Private_AcceptsClusterTargetWithAccessGroups(t *testing.T) {
rp := validProxy()
rp.Private = true
rp.AccessGroups = []string{"grp-admins"}
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
require.NoError(t, rp.Validate())
}
func TestValidate_Private_RejectsNonHTTPMode(t *testing.T) {
rp := validProxy()
rp.Private = true
rp.AccessGroups = []string{"grp-admins"}
rp.Mode = ModeTCP
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "tcp",
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "HTTP")
}

View File

@@ -20,6 +20,20 @@ type KeyPair struct {
type Claims struct { type Claims struct {
jwt.RegisteredClaims jwt.RegisteredClaims
Method auth.Method `json:"method"` Method auth.Method `json:"method"`
// Email is the calling user's email address. Carried so the
// proxy can stamp identity on upstream requests (e.g.
// x-litellm-end-user-id) without an extra management
// round-trip on every cookie-bearing request.
Email string `json:"email,omitempty"`
// Groups carries the user's group IDs so the proxy can stamp them
// onto upstream requests (X-NetBird-Groups) from the cookie path
// without an extra management round-trip.
Groups []string `json:"groups,omitempty"`
// GroupNames carries the human-readable display names for the ids
// in Groups, ordered identically (positional pairing). Slice may be
// shorter than Groups for tokens minted before names were
// resolvable; the consumer falls back to ids for missing positions.
GroupNames []string `json:"group_names,omitempty"`
} }
func GenerateKeyPair() (*KeyPair, error) { func GenerateKeyPair() (*KeyPair, error) {
@@ -34,7 +48,13 @@ func GenerateKeyPair() (*KeyPair, error) {
}, nil }, nil
} }
func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration time.Duration) (string, error) { // SignToken mints a session JWT for the given user and domain. email,
// groups, and groupNames, when non-empty, are embedded so the proxy can
// authorise and stamp identity for policy-aware middlewares without a
// management round-trip on every cookie-bearing request. groupNames
// pairs positionally with groups; pass nil when names couldn't be
// resolved.
func SignToken(privKeyB64, userID, email, domain string, method auth.Method, groups, groupNames []string, expiration time.Duration) (string, error) {
privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyB64) privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyB64)
if err != nil { if err != nil {
return "", fmt.Errorf("decode private key: %w", err) return "", fmt.Errorf("decode private key: %w", err)
@@ -56,7 +76,10 @@ func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration
IssuedAt: jwt.NewNumericDate(now), IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now), NotBefore: jwt.NewNumericDate(now),
}, },
Method: method, Method: method,
Email: email,
Groups: append([]string(nil), groups...),
GroupNames: append([]string(nil), groupNames...),
} }
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims) token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)

View File

@@ -10,8 +10,10 @@ import (
"slices" "slices"
"time" "time"
"github.com/gorilla/mux"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2" grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/rs/cors"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc" "google.golang.org/grpc"
@@ -19,7 +21,6 @@ import (
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
cachestore "github.com/eko/gocache/lib/v4/store" cachestore "github.com/eko/gocache/lib/v4/store"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/formatter/hook"
@@ -27,16 +28,20 @@ import (
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
activitystore "github.com/netbirdio/netbird/management/server/activity/store"
nbcache "github.com/netbirdio/netbird/management/server/cache" nbcache "github.com/netbirdio/netbird/management/server/cache"
nbContext "github.com/netbirdio/netbird/management/server/context" nbContext "github.com/netbirdio/netbird/management/server/context"
nbhttp "github.com/netbirdio/netbird/management/server/http" nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto" mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util/crypt" "github.com/netbirdio/netbird/util/crypt"
) )
const apiPrefix = "/api"
var ( var (
kaep = keepalive.EnforcementPolicy{ kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second, MinTime: 15 * time.Second,
@@ -94,12 +99,17 @@ func (s *BaseServer) Store() store.Store {
func (s *BaseServer) EventStore() activity.Store { func (s *BaseServer) EventStore() activity.Store {
return Create(s, func() activity.Store { return Create(s, func() activity.Store {
integrationMetrics, err := integrations.InitIntegrationMetrics(context.Background(), s.Metrics()) var err error
if err != nil { key := s.Config.DataStoreEncryptionKey
log.Fatalf("failed to initialize integration metrics: %v", err) if key == "" {
log.Debugf("generate new activity store encryption key")
key, err = crypt.GenerateKey()
if err != nil {
log.Fatalf("failed to generate event store encryption key: %v", err)
}
} }
eventStore, _, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics) eventStore, err := activitystore.NewSqlStore(context.Background(), s.Config.Datadir, key)
if err != nil { if err != nil {
log.Fatalf("failed to initialize event store: %v", err) log.Fatalf("failed to initialize event store: %v", err)
} }
@@ -110,7 +120,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler { func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler { return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter()) httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount)
if err != nil { if err != nil {
log.Fatalf("failed to create API handler: %v", err) log.Fatalf("failed to create API handler: %v", err)
} }
@@ -118,6 +128,22 @@ func (s *BaseServer) APIHandler() http.Handler {
}) })
} }
// IDPHandler returns the HTTP handler for the embedded IdP (Dex), or nil if
// the deployment isn't using the embedded variant.
func (s *BaseServer) IDPHandler() http.Handler {
embeddedIdP, ok := s.IdpManager().(*idp.EmbeddedIdPManager)
if !ok || embeddedIdP == nil {
return nil
}
return cors.AllowAll().Handler(embeddedIdP.Handler())
}
func (s *BaseServer) Router() *mux.Router {
return Create(s, func() *mux.Router {
return mux.NewRouter().PathPrefix(apiPrefix).Subrouter()
})
}
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter { func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
return Create(s, func() *middleware.APIRateLimiter { return Create(s, func() *middleware.APIRateLimiter {
cfg, enabled := middleware.RateLimiterConfigFromEnv() cfg, enabled := middleware.RateLimiterConfigFromEnv()

View File

@@ -19,6 +19,7 @@ import (
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/management/server/job"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
@@ -38,7 +39,7 @@ func (s *BaseServer) JobManager() *job.Manager {
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator { func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
return Create(s, func() integrated_validator.IntegratedValidator { return Create(s, func() integrated_validator.IntegratedValidator {
integratedPeerValidator, err := integrations.NewIntegratedValidator( integratedPeerValidator, err := validator.NewIntegratedValidator(
context.Background(), context.Background(),
s.PeersManager(), s.PeersManager(),
s.SettingsManager(), s.SettingsManager(),

View File

@@ -57,13 +57,7 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
func (s *BaseServer) PermissionsManager() permissions.Manager { func (s *BaseServer) PermissionsManager() permissions.Manager {
return Create(s, func() permissions.Manager { return Create(s, func() permissions.Manager {
manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter()) return permissions.NewManager(s.Store())
s.AfterInit(func(s *BaseServer) {
manager.SetAccountManager(s.AccountManager())
})
return manager
}) })
} }
@@ -153,7 +147,6 @@ func (s *BaseServer) IdpManager() idp.Manager {
return idpManager return idpManager
} }
return nil return nil
}) })
} }
@@ -235,3 +228,7 @@ func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
return &m return &m
}) })
} }
func (s *BaseServer) IsValidChildAccount(_ context.Context, _, _, _ string) bool {
return false
}

View File

@@ -188,7 +188,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String()) log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
} }
rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter()) rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.IDPHandler(), s.Metrics().GetMeter())
switch { switch {
case s.certManager != nil: case s.certManager != nil:
// a call to certManager.Listener() always creates a new listener so we do it once // a call to certManager.Listener() always creates a new listener so we do it once
@@ -299,7 +299,7 @@ func (s *BaseServer) SetHandlerFunc(handler http.Handler) {
log.Tracef("custom handler set successfully") log.Tracef("custom handler set successfully")
} }
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler { func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, meter metric.Meter) http.Handler {
// Check if a custom handler was set (for multiplexing additional services) // Check if a custom handler was set (for multiplexing additional services)
if customHandler, ok := s.GetContainer("customHandler"); ok { if customHandler, ok := s.GetContainer("customHandler"); ok {
if handler, ok := customHandler.(http.Handler); ok { if handler, ok := customHandler.(http.Handler); ok {
@@ -318,6 +318,8 @@ func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, ht
gRPCHandler.ServeHTTP(writer, request) gRPCHandler.ServeHTTP(writer, request)
case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent: case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
wsProxy.Handler().ServeHTTP(writer, request) wsProxy.Handler().ServeHTTP(writer, request)
case idpHandler != nil && strings.HasPrefix(request.URL.Path, "/oauth2"):
idpHandler.ServeHTTP(writer, request)
default: default:
httpHandler.ServeHTTP(writer, request) httpHandler.ServeHTTP(writer, request)
} }

View File

@@ -0,0 +1,813 @@
package grpc
import (
"encoding/base64"
"strconv"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/networkmap"
"github.com/netbirdio/netbird/shared/management/proto"
)
// wgKeyRawLen is the raw byte length of a WireGuard public key.
const wgKeyRawLen = 32
// ComponentsEnvelopeInput bundles the data the component-format encoder needs.
// The envelope is fully self-contained — every field needed by the client's
// local Calculate() comes from the components struct itself. The only
// externally-supplied data is the receiving peer's PeerConfig (which is
// computed alongside the components in the network_map controller and reused
// from the legacy proto path) and the dns_domain string.
type ComponentsEnvelopeInput struct {
Components *types.NetworkMapComponents
PeerConfig *proto.PeerConfig
DNSDomain string
DNSForwarderPort int64
// UserIDClaim is the OIDC claim name the client should embed in
// SshAuth.UserIDClaim when reconstructing the NetworkMap. Empty value
// is OK — client treats empty as "no SshAuth to build".
UserIDClaim string
// ProxyPatch carries pre-expanded NetworkMap fragments injected by
// external controllers (BYOP/port-forwarding). Nil when no proxy data
// is present; encoder skips the field in that case.
ProxyPatch *proto.ProxyPatch
}
// EncodeNetworkMapEnvelope converts NetworkMapComponents into the component
// wire envelope. The encoder is intentionally non-deterministic: it iterates
// Go maps in their native (random) order. Indexes inside the envelope
// (peer_indexes, source_group_ids, agent_version_idx, router_peer_indexes)
// are self-consistent within a single encode, so the decoder reconstructs
// the same typed objects regardless of emit order. Tests that need to
// compare envelopes do so semantically via proto round-trip + canonicalize,
// not byte-equal.
//
// Callers must NOT concatenate or merge envelopes from different encodes —
// index spaces are local to a single envelope.
func EncodeNetworkMapEnvelope(in ComponentsEnvelopeInput) *proto.NetworkMapEnvelope {
c := in.Components
// Graceful degrade when components is nil — matches the legacy path's
// behaviour for missing/unvalidated peers (return a NetworkMap with only
// Network populated). The receiver gets an envelope it can decode
// without crashing; AccountSettings stays non-nil so client-side
// dereferences are safe.
if c == nil {
// Match legacy missing-peer minimum: a NetworkMap with only Network
// populated. The receiver gets enough to bootstrap (Network
// identifier, dns_domain, account_settings) and nothing else.
return &proto.NetworkMapEnvelope{
Payload: &proto.NetworkMapEnvelope_Full{
Full: &proto.NetworkMapComponentsFull{
PeerConfig: in.PeerConfig,
DnsDomain: in.DNSDomain,
DnsForwarderPort: in.DNSForwarderPort,
UserIdClaim: in.UserIDClaim,
AccountSettings: &proto.AccountSettingsCompact{},
ProxyPatch: in.ProxyPatch,
},
},
}
}
// Phase 1: build dedup tables. Every routing peer (in c.RouterPeers) and
// every regular peer (in c.Peers) must be indexed before any encoder
// looks up indexes via e.peerOrder — otherwise routes / routers_map for
// peers that exist only in c.RouterPeers would silently lose their
// peer_index reference.
enc := newComponentEncoder(c)
enc.indexAllPeers()
routerIdxs := enc.indexRouterPeers(c.RouterPeers)
// Phase 2: gather every policy that any consumer references (peer-pair
// policies + resource-only policies) so encodeResourcePoliciesMap can
// translate every *Policy pointer to a wire index.
allPolicies := unionPolicies(c.Policies, c.ResourcePoliciesMap)
policies, policyToIdxs := enc.encodePolicies(allPolicies)
// Phase 3: emit. Order of struct field expressions no longer matters:
// every encoder either reads from the dedup tables or works on
// independent input.
full := &proto.NetworkMapComponentsFull{
Serial: networkSerial(c.Network),
PeerConfig: in.PeerConfig,
Network: toAccountNetwork(c.Network),
AccountSettings: toAccountSettingsCompact(c.AccountSettings),
DnsForwarderPort: in.DNSForwarderPort,
UserIdClaim: in.UserIDClaim,
ProxyPatch: in.ProxyPatch,
DnsSettings: enc.encodeDNSSettings(c.DNSSettings),
DnsDomain: in.DNSDomain,
CustomZoneDomain: c.CustomZoneDomain,
AgentVersions: enc.agentVersions,
Peers: enc.peers,
RouterPeerIndexes: routerIdxs,
Policies: policies,
Groups: enc.encodeGroups(),
Routes: enc.encodeRoutes(c.Routes),
NameserverGroups: enc.encodeNameServerGroups(c.NameServerGroups),
AllDnsRecords: encodeSimpleRecords(c.AllDNSRecords),
AccountZones: encodeCustomZones(c.AccountZones),
NetworkResources: enc.encodeNetworkResources(c.NetworkResources),
RoutersMap: enc.encodeRoutersMap(c.RoutersMap),
ResourcePoliciesMap: enc.encodeResourcePoliciesMap(c.ResourcePoliciesMap, policyToIdxs),
GroupIdToUserIds: enc.encodeGroupIDToUserIDs(c.GroupIDToUserIDs),
AllowedUserIds: stringSetToSlice(c.AllowedUserIDs),
PostureFailedPeers: enc.encodePostureFailedPeers(c.PostureFailedPeers),
}
return &proto.NetworkMapEnvelope{
Payload: &proto.NetworkMapEnvelope_Full{Full: full},
}
}
// networkSerial returns c.Network.CurrentSerial() with a nil guard. The
// production path always populates c.Network, but the encoder is exported
// and a hand-built components struct may omit it.
func networkSerial(n *types.Network) uint64 {
if n == nil {
return 0
}
return n.CurrentSerial()
}
type componentEncoder struct {
components *types.NetworkMapComponents
peerOrder map[string]uint32
peers []*proto.PeerCompact
agentVersionOrder map[string]uint32
agentVersions []string
}
func newComponentEncoder(c *types.NetworkMapComponents) *componentEncoder {
return &componentEncoder{
components: c,
peerOrder: make(map[string]uint32, len(c.Peers)),
peers: make([]*proto.PeerCompact, 0, len(c.Peers)),
agentVersionOrder: make(map[string]uint32),
}
}
func (e *componentEncoder) indexAllPeers() {
for _, p := range e.components.Peers {
if p == nil {
continue
}
e.appendPeer(p)
}
}
func (e *componentEncoder) appendPeer(p *nbpeer.Peer) uint32 {
if idx, ok := e.peerOrder[p.ID]; ok {
return idx
}
idx := uint32(len(e.peers))
e.peerOrder[p.ID] = idx
e.peers = append(e.peers, toPeerCompact(p, e.agentVersionIndex(p.Meta.WtVersion)))
return idx
}
func (e *componentEncoder) agentVersionIndex(v string) uint32 {
if idx, ok := e.agentVersionOrder[v]; ok {
return idx
}
// Lazy-initialise the table with "" at index 0 so the empty string
// stays interchangeable with proto3's default uint32=0 — peers without
// a WtVersion don't force the table to materialise.
if v == "" {
idx := uint32(len(e.agentVersions))
if idx == 0 {
e.agentVersions = append(e.agentVersions, "")
}
e.agentVersionOrder[""] = idx
return idx
}
if len(e.agentVersions) == 0 {
e.agentVersions = append(e.agentVersions, "")
e.agentVersionOrder[""] = 0
}
idx := uint32(len(e.agentVersions))
e.agentVersionOrder[v] = idx
e.agentVersions = append(e.agentVersions, v)
return idx
}
// indexRouterPeers ensures every router peer is in the peer dedup table
// (c.RouterPeers may contain peers not in c.Peers when validation rules drop
// them) and returns their wire indexes for the RouterPeerIndexes field. Must
// run before any encoder that resolves peer ids via e.peerOrder.
func (e *componentEncoder) indexRouterPeers(routers map[string]*nbpeer.Peer) []uint32 {
if len(routers) == 0 {
return nil
}
out := make([]uint32, 0, len(routers))
for _, p := range routers {
if p == nil {
continue
}
out = append(out, e.appendPeer(p))
}
return out
}
func (e *componentEncoder) encodeGroups() []*proto.GroupCompact {
if len(e.components.Groups) == 0 {
return nil
}
out := make([]*proto.GroupCompact, 0, len(e.components.Groups))
for _, g := range e.components.Groups {
if !g.HasSeqID() {
continue
}
peerIdxs := make([]uint32, 0, len(g.Peers))
for _, peerID := range g.Peers {
if idx, ok := e.peerOrder[peerID]; ok {
peerIdxs = append(peerIdxs, idx)
}
}
out = append(out, &proto.GroupCompact{
Id: g.AccountSeqID,
Name: g.Name,
PeerIndexes: peerIdxs,
})
}
return out
}
// encodePolicies flattens Policy{Rules} → []PolicyCompact. Returns the wire
// list and a map from policy pointer to the indexes of its emitted rules in
// that list — used by encodeResourcePoliciesMap to translate
// ResourcePoliciesMap[resourceID][]*Policy into wire-side indexes.
func (e *componentEncoder) encodePolicies(policies []*types.Policy) ([]*proto.PolicyCompact, map[*types.Policy][]uint32) {
if len(policies) == 0 {
return nil, nil
}
out := make([]*proto.PolicyCompact, 0, len(policies))
idxByPolicy := make(map[*types.Policy][]uint32, len(policies))
for _, pol := range policies {
if !pol.HasSeqID() || !pol.Enabled {
continue
}
for _, r := range pol.Rules {
if r == nil || !r.Enabled {
continue
}
idxByPolicy[pol] = append(idxByPolicy[pol], uint32(len(out)))
out = append(out, e.encodePolicyRule(pol, r))
}
}
return out, idxByPolicy
}
// encodePolicyRule maps a single PolicyRule under pol to a PolicyCompact entry.
func (e *componentEncoder) encodePolicyRule(pol *types.Policy, r *types.PolicyRule) *proto.PolicyCompact {
return &proto.PolicyCompact{
Id: pol.AccountSeqID,
Action: networkmap.GetProtoAction(string(r.Action)),
Protocol: networkmap.GetProtoProtocol(string(r.Protocol)),
Bidirectional: r.Bidirectional,
Ports: portsToUint32(r.Ports),
PortRanges: portRangesToProto(r.PortRanges),
SourceGroupIds: e.groupSeqIDs(r.Sources),
DestinationGroupIds: e.groupSeqIDs(r.Destinations),
AuthorizedUser: r.AuthorizedUser,
AuthorizedGroups: e.encodeAuthorizedGroups(r.AuthorizedGroups),
SourceResource: e.resourceToProto(r.SourceResource),
DestinationResource: e.resourceToProto(r.DestinationResource),
SourcePostureCheckSeqIds: e.postureCheckSeqs(pol.SourcePostureChecks),
}
}
// groupSeqIDs maps the xid group IDs in src to their per-account seq ids,
// dropping any group that has no seq id assigned.
func (e *componentEncoder) groupSeqIDs(src []string) []uint32 {
if len(src) == 0 {
return nil
}
out := make([]uint32, 0, len(src))
for _, gid := range src {
if seq, ok := e.groupSeq(gid); ok {
out = append(out, seq)
}
}
return out
}
// unionPolicies merges c.Policies with every policy referenced by
// c.ResourcePoliciesMap, deduplicating by pointer identity. Resource-only
// policies (relevant to a NetworkResource but not to peer-pair traffic)
// only live in ResourcePoliciesMap; without this union step they'd be lost
// from the wire and the client's resource-policy lookup would come back
// empty.
func unionPolicies(policies []*types.Policy, resourcePolicies map[string][]*types.Policy) []*types.Policy {
// Fast path: non-router peers have no resource-only policies, so the
// "union" is identical to `policies`. Skip the dedup map allocation.
if len(resourcePolicies) == 0 {
return policies
}
seen := make(map[*types.Policy]struct{}, len(policies))
out := make([]*types.Policy, 0, len(policies))
for _, p := range policies {
if p == nil {
continue
}
if _, ok := seen[p]; ok {
continue
}
seen[p] = struct{}{}
out = append(out, p)
}
for _, list := range resourcePolicies {
for _, p := range list {
if p == nil {
continue
}
if _, ok := seen[p]; ok {
continue
}
seen[p] = struct{}{}
out = append(out, p)
}
}
return out
}
// encodeAuthorizedGroups translates rule.AuthorizedGroups (map keyed by
// group xid → local-user names) to the wire form (map keyed by group
// account_seq_id → UserNameList). Groups without a seq id are dropped —
// matches how source/destination group references handle the same case.
func (e *componentEncoder) encodeAuthorizedGroups(m map[string][]string) map[uint32]*proto.UserNameList {
if len(m) == 0 {
return nil
}
out := make(map[uint32]*proto.UserNameList, len(m))
for groupID, names := range m {
seq, ok := e.groupSeq(groupID)
if !ok {
continue
}
out[seq] = &proto.UserNameList{Names: append([]string(nil), names...)}
}
return out
}
func (e *componentEncoder) groupSeq(groupID string) (uint32, bool) {
g, ok := e.components.Groups[groupID]
if !ok || !g.HasSeqID() {
return 0, false
}
return g.AccountSeqID, true
}
// resourceToProto translates types.Resource for the wire. For peer-typed
// resources the peer id is converted to a peer index into the envelope's
// peers array. For other resource types only the type string is shipped
// today (Calculate's resource-typed rule path consults SourceResource only
// for "peer" — other types fall through to group-based lookup).
func (e *componentEncoder) resourceToProto(r types.Resource) *proto.ResourceCompact {
if r.ID == "" && r.Type == "" {
return nil
}
out := &proto.ResourceCompact{Type: string(r.Type)}
if r.Type == types.ResourceTypePeer && r.ID != "" {
if idx, ok := e.peerOrder[r.ID]; ok {
out.PeerIndexSet = true
out.PeerIndex = idx
}
}
return out
}
// postureCheckSeqs translates a slice of posture-check xids to their
// per-account integer ids using the NetworkMapComponents.PostureCheckXIDToSeq
// lookup. Unresolvable xids are silently dropped — matches how group/peer
// references handle the same case.
func (e *componentEncoder) postureCheckSeqs(xids []string) []uint32 {
if len(xids) == 0 || len(e.components.PostureCheckXIDToSeq) == 0 {
return nil
}
out := make([]uint32, 0, len(xids))
for _, xid := range xids {
if seq, ok := e.components.PostureCheckXIDToSeq[xid]; ok {
out = append(out, seq)
}
}
return out
}
// networkSeq translates a Network xid to its per-account integer id using
// the NetworkMapComponents.NetworkXIDToSeq lookup. Returns (0,false) when
// the xid isn't known — callers decide whether to skip the parent record.
func (e *componentEncoder) networkSeq(xid string) (uint32, bool) {
if xid == "" {
return 0, false
}
seq, ok := e.components.NetworkXIDToSeq[xid]
if !ok || seq == 0 {
return 0, false
}
return seq, true
}
func (e *componentEncoder) encodeDNSSettings(s *types.DNSSettings) *proto.DNSSettingsCompact {
if s == nil || len(s.DisabledManagementGroups) == 0 {
return nil
}
out := &proto.DNSSettingsCompact{
DisabledManagementGroupIds: make([]uint32, 0, len(s.DisabledManagementGroups)),
}
for _, gid := range s.DisabledManagementGroups {
if seq, ok := e.groupSeq(gid); ok {
out.DisabledManagementGroupIds = append(out.DisabledManagementGroupIds, seq)
}
}
return out
}
func (e *componentEncoder) encodeRoutes(routes []*nbroute.Route) []*proto.RouteRaw {
if len(routes) == 0 {
return nil
}
out := make([]*proto.RouteRaw, 0, len(routes))
for _, r := range routes {
if r == nil {
continue
}
rr := &proto.RouteRaw{
Id: r.AccountSeqID,
NetId: string(r.NetID),
Description: r.Description,
KeepRoute: r.KeepRoute,
NetworkType: int32(r.NetworkType),
Masquerade: r.Masquerade,
Metric: int32(r.Metric),
Enabled: r.Enabled,
SkipAutoApply: r.SkipAutoApply,
Domains: r.Domains.ToPunycodeList(),
GroupIds: e.groupIDsToSeq(r.Groups),
AccessControlGroupIds: e.groupIDsToSeq(r.AccessControlGroups),
PeerGroupIds: e.groupIDsToSeq(r.PeerGroups),
}
if r.Network.IsValid() {
rr.NetworkCidr = r.Network.String()
}
if r.Peer != "" {
if idx, ok := e.peerOrder[r.Peer]; ok {
rr.PeerIndexSet = true
rr.PeerIndex = idx
}
}
out = append(out, rr)
}
return out
}
func (e *componentEncoder) groupIDsToSeq(groupIDs []string) []uint32 {
if len(groupIDs) == 0 {
return nil
}
out := make([]uint32, 0, len(groupIDs))
for _, gid := range groupIDs {
if seq, ok := e.groupSeq(gid); ok {
out = append(out, seq)
}
}
return out
}
func (e *componentEncoder) encodeNameServerGroups(nsgs []*nbdns.NameServerGroup) []*proto.NameServerGroupRaw {
if len(nsgs) == 0 {
return nil
}
out := make([]*proto.NameServerGroupRaw, 0, len(nsgs))
for _, nsg := range nsgs {
if nsg == nil {
continue
}
entry := &proto.NameServerGroupRaw{
Id: nsg.AccountSeqID,
Name: nsg.Name,
Description: nsg.Description,
Nameservers: encodeNameServers(nsg.NameServers),
GroupIds: e.groupIDsToSeq(nsg.Groups),
Primary: nsg.Primary,
Domains: nsg.Domains,
Enabled: nsg.Enabled,
SearchDomainsEnabled: nsg.SearchDomainsEnabled,
}
out = append(out, entry)
}
return out
}
func encodeNameServers(servers []nbdns.NameServer) []*proto.NameServer {
if len(servers) == 0 {
return nil
}
out := make([]*proto.NameServer, 0, len(servers))
for _, s := range servers {
out = append(out, &proto.NameServer{
IP: s.IP.String(),
NSType: int64(s.NSType),
Port: int64(s.Port),
})
}
return out
}
func encodeSimpleRecords(records []nbdns.SimpleRecord) []*proto.SimpleRecord {
if len(records) == 0 {
return nil
}
out := make([]*proto.SimpleRecord, 0, len(records))
for _, r := range records {
out = append(out, &proto.SimpleRecord{
Name: r.Name,
Type: int64(r.Type),
Class: r.Class,
TTL: int64(r.TTL),
RData: r.RData,
})
}
return out
}
func encodeCustomZones(zones []nbdns.CustomZone) []*proto.CustomZone {
if len(zones) == 0 {
return nil
}
out := make([]*proto.CustomZone, 0, len(zones))
for _, z := range zones {
out = append(out, &proto.CustomZone{
Domain: z.Domain,
Records: encodeSimpleRecords(z.Records),
SearchDomainDisabled: z.SearchDomainDisabled,
NonAuthoritative: z.NonAuthoritative,
})
}
return out
}
func (e *componentEncoder) encodeNetworkResources(resources []*resourceTypes.NetworkResource) []*proto.NetworkResourceRaw {
if len(resources) == 0 {
return nil
}
out := make([]*proto.NetworkResourceRaw, 0, len(resources))
for _, r := range resources {
if r == nil {
continue
}
entry := &proto.NetworkResourceRaw{
Id: r.AccountSeqID,
Name: r.Name,
Description: r.Description,
Type: string(r.Type),
Address: r.Address,
DomainValue: r.Domain,
Enabled: r.Enabled,
}
if seq, ok := e.networkSeq(r.NetworkID); ok {
entry.NetworkSeq = seq
}
if r.Prefix.IsValid() {
entry.PrefixCidr = r.Prefix.String()
}
out = append(out, entry)
}
return out
}
func (e *componentEncoder) encodeRoutersMap(routersMap map[string]map[string]*routerTypes.NetworkRouter) map[uint32]*proto.NetworkRouterList {
if len(routersMap) == 0 {
return nil
}
out := make(map[uint32]*proto.NetworkRouterList, len(routersMap))
for networkXID, routers := range routersMap {
if len(routers) == 0 {
continue
}
netSeq, ok := e.networkSeq(networkXID)
if !ok {
continue
}
entries := make([]*proto.NetworkRouterEntry, 0, len(routers))
for peerID, r := range routers {
if r == nil {
continue
}
entry := &proto.NetworkRouterEntry{
Id: r.AccountSeqID,
PeerGroupIds: e.groupIDsToSeq(r.PeerGroups),
Masquerade: r.Masquerade,
Metric: int32(r.Metric),
Enabled: r.Enabled,
}
if idx, ok := e.peerOrder[peerID]; ok {
entry.PeerIndexSet = true
entry.PeerIndex = idx
}
entries = append(entries, entry)
}
out[netSeq] = &proto.NetworkRouterList{Entries: entries}
}
return out
}
func (e *componentEncoder) encodeResourcePoliciesMap(rpm map[string][]*types.Policy, policyToIdxs map[*types.Policy][]uint32) map[uint32]*proto.PolicyIndexes {
if len(rpm) == 0 {
return nil
}
// resourceXIDToSeq is local to one encode — built from components.NetworkResources
// (small slice). Network resources without seq id are dropped, matching how
// other components-without-seq are silently filtered.
resourceXIDToSeq := make(map[string]uint32, len(e.components.NetworkResources))
for _, r := range e.components.NetworkResources {
if r != nil && r.AccountSeqID != 0 {
resourceXIDToSeq[r.ID] = r.AccountSeqID
}
}
out := make(map[uint32]*proto.PolicyIndexes, len(rpm))
for resourceXID, policies := range rpm {
seq, ok := resourceXIDToSeq[resourceXID]
if !ok {
continue
}
idxs := make([]uint32, 0, len(policies)*2)
for _, pol := range policies {
idxs = append(idxs, policyToIdxs[pol]...)
}
if len(idxs) == 0 {
continue
}
out[seq] = &proto.PolicyIndexes{Indexes: idxs}
}
return out
}
func (e *componentEncoder) encodeGroupIDToUserIDs(m map[string][]string) map[uint32]*proto.UserIDList {
if len(m) == 0 {
return nil
}
out := make(map[uint32]*proto.UserIDList, len(m))
for groupID, userIDs := range m {
seq, ok := e.groupSeq(groupID)
if !ok || len(userIDs) == 0 {
continue
}
out[seq] = &proto.UserIDList{UserIds: userIDs}
}
return out
}
func stringSetToSlice(s map[string]struct{}) []string {
if len(s) == 0 {
return nil
}
out := make([]string, 0, len(s))
for k := range s {
out = append(out, k)
}
return out
}
func (e *componentEncoder) encodePostureFailedPeers(m map[string]map[string]struct{}) map[uint32]*proto.PeerIndexSet {
if len(m) == 0 {
return nil
}
out := make(map[uint32]*proto.PeerIndexSet, len(m))
for checkXID, failedPeerIDs := range m {
seq, ok := e.components.PostureCheckXIDToSeq[checkXID]
if !ok || seq == 0 {
continue
}
idxs := make([]uint32, 0, len(failedPeerIDs))
for peerID := range failedPeerIDs {
if idx, ok := e.peerOrder[peerID]; ok {
idxs = append(idxs, idx)
}
}
if len(idxs) == 0 {
continue
}
out[seq] = &proto.PeerIndexSet{PeerIndexes: idxs}
}
return out
}
// toAccountSettingsCompact always returns a non-nil message — the client
// dereferences it unconditionally during Calculate(), so a nil here would
// crash the receiver. A missing types.AccountSettingsInfo on the server
// (which shouldn't happen in production but the encoder is exported)
// degrades to login_expiration_enabled = false, which makes
// LoginExpired() return false for every peer.
func toAccountSettingsCompact(s *types.AccountSettingsInfo) *proto.AccountSettingsCompact {
if s == nil {
return &proto.AccountSettingsCompact{}
}
return &proto.AccountSettingsCompact{
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
PeerLoginExpirationNs: int64(s.PeerLoginExpiration),
}
}
func toAccountNetwork(n *types.Network) *proto.AccountNetwork {
if n == nil {
return nil
}
out := &proto.AccountNetwork{
Identifier: n.Identifier,
NetCidr: n.Net.String(),
Dns: n.Dns,
Serial: n.CurrentSerial(),
}
if len(n.NetV6.IP) > 0 {
out.NetV6Cidr = n.NetV6.String()
}
return out
}
func toPeerCompact(p *nbpeer.Peer, agentVersionIdx uint32) *proto.PeerCompact {
pc := &proto.PeerCompact{
WgPubKey: decodeWgKey(p.Key),
SshPubKey: []byte(p.SSHKey),
DnsLabel: p.DNSLabel,
AgentVersionIdx: agentVersionIdx,
AddedWithSsoLogin: p.UserID != "",
LoginExpirationEnabled: p.LoginExpirationEnabled,
SshEnabled: p.SSHEnabled,
SupportsIpv6: p.SupportsIPv6(),
SupportsSourcePrefixes: p.SupportsSourcePrefixes(),
ServerSshAllowed: p.Meta.Flags.ServerSSHAllowed,
}
if p.LastLogin != nil {
pc.LastLoginUnixNano = p.LastLogin.UnixNano()
}
switch {
case !p.IP.IsValid():
// leave Ip nil
case p.IP.Is4() || p.IP.Is4In6():
ip := p.IP.Unmap().As4()
pc.Ip = ip[:]
default:
ip := p.IP.As16()
pc.Ip = ip[:]
}
if p.IPv6.IsValid() {
ip := p.IPv6.As16()
pc.Ipv6 = ip[:]
}
return pc
}
// decodeWgKey returns the raw 32 bytes of a base64-encoded WireGuard public
// key, or nil for an empty / malformed key.
func decodeWgKey(s string) []byte {
if s == "" {
return nil
}
out := make([]byte, wgKeyRawLen)
n, err := base64.StdEncoding.Decode(out, []byte(s))
if err != nil || n != wgKeyRawLen {
return nil
}
return out
}
func portsToUint32(ports []string) []uint32 {
if len(ports) == 0 {
return nil
}
out := make([]uint32, 0, len(ports))
for _, p := range ports {
v, err := strconv.ParseUint(p, 10, 16)
if err != nil {
continue
}
out = append(out, uint32(v))
}
return out
}
func portRangesToProto(ranges []types.RulePortRange) []*proto.PortInfo_Range {
if len(ranges) == 0 {
return nil
}
out := make([]*proto.PortInfo_Range, 0, len(ranges))
for _, r := range ranges {
out = append(out, &proto.PortInfo_Range{
Start: uint32(r.Start),
End: uint32(r.End),
})
}
return out
}

View File

@@ -0,0 +1,879 @@
package grpc
import (
"bytes"
"cmp"
"net"
"net/netip"
"slices"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
goproto "google.golang.org/protobuf/proto"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/proto"
)
const testWgKeyA = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
const testWgKeyB = "BBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
const testWgKeyC = "CBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
// canonicalize rewrites a NetworkMapComponentsFull in place into a canonical
// form: peers reordered by wg_pub_key, with the rest of the message rewritten
// to reference the new peer indexes. Groups, policies, and router indexes are
// also sorted. After canonicalize, two envelopes built from the same logical
// input compare byte-equal via proto.Equal.
//
// This lives on the test side — the encoder itself emits in map-iteration
// order. Test-side normalization is the contract for "two encodes are
// equivalent".
func canonicalize(full *proto.NetworkMapComponentsFull) {
if full == nil {
return
}
// Canonicalize agent_versions first: sort the slice and rewrite each
// peer's AgentVersionIdx accordingly. The empty placeholder stays at
// index 0 by convention.
avRemap := make(map[uint32]uint32, len(full.AgentVersions))
if len(full.AgentVersions) > 0 {
// Pair version → original index, sort, rebuild.
type avEntry struct {
version string
oldIdx uint32
}
entries := make([]avEntry, len(full.AgentVersions))
for i, v := range full.AgentVersions {
entries[i] = avEntry{version: v, oldIdx: uint32(i)}
}
// Empty stays at 0; sort the rest by string. Tiebreaker on oldIdx
// keeps the canonicalize output stable when two entries compare
// equal (the encoder dedups, but defending against future inputs).
slices.SortFunc(entries, func(a, b avEntry) int {
if a.version == "" && b.version != "" {
return -1
}
if b.version == "" && a.version != "" {
return 1
}
if c := cmp.Compare(a.version, b.version); c != 0 {
return c
}
return cmp.Compare(a.oldIdx, b.oldIdx)
})
newVersions := make([]string, len(entries))
for newIdx, e := range entries {
avRemap[e.oldIdx] = uint32(newIdx)
newVersions[newIdx] = e.version
}
full.AgentVersions = newVersions
}
for _, p := range full.Peers {
if newIdx, ok := avRemap[p.AgentVersionIdx]; ok {
p.AgentVersionIdx = newIdx
}
}
type peerEntry struct {
peer *proto.PeerCompact
oldIdx uint32
}
entries := make([]peerEntry, len(full.Peers))
for i, p := range full.Peers {
entries[i] = peerEntry{peer: p, oldIdx: uint32(i)}
}
// DnsLabel is unique per peer; it tiebreaks on equal WgPubKey (e.g. both
// nil from malformed keys, or both empty for placeholders).
slices.SortFunc(entries, func(a, b peerEntry) int {
if c := bytes.Compare(a.peer.WgPubKey, b.peer.WgPubKey); c != 0 {
return c
}
return cmp.Compare(a.peer.DnsLabel, b.peer.DnsLabel)
})
remap := make(map[uint32]uint32, len(entries))
newPeers := make([]*proto.PeerCompact, len(entries))
for newIdx, e := range entries {
remap[e.oldIdx] = uint32(newIdx)
newPeers[newIdx] = e.peer
}
full.Peers = newPeers
full.RouterPeerIndexes = remapAndSort(full.RouterPeerIndexes, remap)
for _, g := range full.Groups {
g.PeerIndexes = remapAndSort(g.PeerIndexes, remap)
}
slices.SortFunc(full.Groups, func(a, b *proto.GroupCompact) int { return cmp.Compare(a.Id, b.Id) })
for _, r := range full.Routes {
if r.PeerIndexSet {
if newIdx, ok := remap[r.PeerIndex]; ok {
r.PeerIndex = newIdx
}
}
slices.Sort(r.GroupIds)
slices.Sort(r.AccessControlGroupIds)
slices.Sort(r.PeerGroupIds)
}
slices.SortFunc(full.Routes, func(a, b *proto.RouteRaw) int { return cmp.Compare(a.Id, b.Id) })
for _, list := range full.RoutersMap {
for _, entry := range list.Entries {
if entry.PeerIndexSet {
if newIdx, ok := remap[entry.PeerIndex]; ok {
entry.PeerIndex = newIdx
}
}
slices.Sort(entry.PeerGroupIds)
}
slices.SortFunc(list.Entries, func(a, b *proto.NetworkRouterEntry) int { return cmp.Compare(a.Id, b.Id) })
}
for _, set := range full.PostureFailedPeers {
set.PeerIndexes = remapAndSort(set.PeerIndexes, remap)
}
for _, p := range full.Policies {
slices.Sort(p.SourceGroupIds)
slices.Sort(p.DestinationGroupIds)
}
// Sort policies by (Id, source_group_ids, destination_group_ids) so that
// multiple PolicyCompact entries sharing the same Id (one per rule, when
// a Policy has multiple rules) still get a deterministic order. After
// sorting we remap indexes in ResourcePoliciesMap.
policyOldOrder := make(map[*proto.PolicyCompact]uint32, len(full.Policies))
for i, p := range full.Policies {
policyOldOrder[p] = uint32(i)
}
slices.SortFunc(full.Policies, func(a, b *proto.PolicyCompact) int {
if c := cmp.Compare(a.Id, b.Id); c != 0 {
return c
}
if c := slices.Compare(a.SourceGroupIds, b.SourceGroupIds); c != 0 {
return c
}
return slices.Compare(a.DestinationGroupIds, b.DestinationGroupIds)
})
policyRemap := make(map[uint32]uint32, len(full.Policies))
for newIdx, p := range full.Policies {
policyRemap[policyOldOrder[p]] = uint32(newIdx)
}
for _, idxs := range full.ResourcePoliciesMap {
idxs.Indexes = remapAndSort(idxs.Indexes, policyRemap)
}
for _, list := range full.GroupIdToUserIds {
slices.Sort(list.UserIds)
}
slices.Sort(full.AllowedUserIds)
}
func remapAndSort(idxs []uint32, remap map[uint32]uint32) []uint32 {
out := make([]uint32, 0, len(idxs))
for _, i := range idxs {
if newIdx, ok := remap[i]; ok {
out = append(out, newIdx)
}
}
slices.Sort(out)
return out
}
// envelopesEquivalent decodes both envelopes, canonicalizes them, and reports
// whether they're proto.Equal. Use instead of byte-comparing marshaled output:
// the encoder is intentionally non-deterministic.
func envelopesEquivalent(a, b *proto.NetworkMapEnvelope) bool {
canonicalize(a.GetFull())
canonicalize(b.GetFull())
return goproto.Equal(a, b)
}
func newTestComponents() *types.NetworkMapComponents {
peerA := &nbpeer.Peer{
ID: "peer-a",
Key: testWgKeyA,
IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}),
DNSLabel: "peera",
SSHKey: "ssh-a",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
peerB := &nbpeer.Peer{
ID: "peer-b",
Key: testWgKeyB,
IP: netip.AddrFrom4([4]byte{100, 64, 0, 2}),
IPv6: netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}),
DNSLabel: "peerb",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.25.0"},
}
peerC := &nbpeer.Peer{
ID: "peer-c",
Key: testWgKeyC,
IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}),
DNSLabel: "peerc",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
return &types.NetworkMapComponents{
PeerID: "peer-a",
Network: &types.Network{
Identifier: "net-test",
Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)},
Serial: 7,
},
AccountSettings: &types.AccountSettingsInfo{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: 2 * time.Hour,
},
Peers: map[string]*nbpeer.Peer{
"peer-a": peerA,
"peer-b": peerB,
"peer-c": peerC,
},
Groups: map[string]*types.Group{
"group-src": {ID: "group-src", AccountSeqID: 1, Name: "Src", Peers: []string{"peer-a"}},
"group-dst": {ID: "group-dst", AccountSeqID: 2, Name: "Dst", Peers: []string{"peer-b", "peer-c"}},
},
Policies: []*types.Policy{
{
ID: "pol-1",
AccountSeqID: 10,
Enabled: true,
Rules: []*types.PolicyRule{{
ID: "rule-1", Enabled: true, Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolTCP, Bidirectional: true,
Ports: []string{"22", "80"},
PortRanges: []types.RulePortRange{{Start: 8000, End: 8100}},
Sources: []string{"group-src"},
Destinations: []string{"group-dst"},
}},
},
},
RouterPeers: map[string]*nbpeer.Peer{"peer-c": peerC},
}
}
func TestEncodeNetworkMapEnvelope_Basic(t *testing.T) {
c := newTestComponents()
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: c,
DNSDomain: "netbird.cloud",
})
require.NotNil(t, env)
full := env.GetFull()
require.NotNil(t, full, "envelope must contain Full payload")
assert.EqualValues(t, 7, full.Serial)
assert.Equal(t, "netbird.cloud", full.DnsDomain)
require.NotNil(t, full.Network)
assert.Equal(t, "net-test", full.Network.Identifier)
assert.Equal(t, "100.64.0.0/10", full.Network.NetCidr)
require.NotNil(t, full.AccountSettings)
assert.True(t, full.AccountSettings.PeerLoginExpirationEnabled)
assert.EqualValues(t, (2 * time.Hour).Nanoseconds(), full.AccountSettings.PeerLoginExpirationNs)
require.Len(t, full.Peers, 3)
byLabel := map[string]*proto.PeerCompact{}
for _, p := range full.Peers {
assert.Len(t, p.WgPubKey, 32, "wg key must be raw 32 bytes")
assert.Len(t, p.Ip, 4, "ipv4 must be raw 4 bytes")
byLabel[p.DnsLabel] = p
}
assert.Len(t, byLabel["peerb"].Ipv6, 16, "peer-b has ipv6 → 16 bytes")
}
func TestEncodeNetworkMapEnvelope_RepeatEncodesEquivalent(t *testing.T) {
c := newTestComponents()
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
// Hammer it 100 times — Go map iteration is randomized per call, so each
// run produces different wire bytes, but the canonicalized form must
// match.
for i := 0; i < 100; i++ {
got := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
require.True(t, envelopesEquivalent(expected, got),
"encode #%d must be semantically equivalent to first encode", i)
}
}
func TestEncodeNetworkMapEnvelope_ConcurrentEncodesEquivalent(t *testing.T) {
c := newTestComponents()
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
const goroutines = 50
var wg sync.WaitGroup
wg.Add(goroutines)
results := make([]*proto.NetworkMapEnvelope, goroutines)
for i := 0; i < goroutines; i++ {
i := i
go func() {
defer wg.Done()
results[i] = EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
}()
}
wg.Wait()
for i, got := range results {
require.NotNil(t, got, "goroutine %d returned nil", i)
require.True(t, envelopesEquivalent(expected, got),
"goroutine %d produced inequivalent envelope", i)
}
}
func TestEncodeNetworkMapEnvelope_GroupsByAccountSeqID(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Groups, 2)
groupByID := map[uint32]*proto.GroupCompact{}
for _, g := range full.Groups {
groupByID[g.Id] = g
}
require.Contains(t, groupByID, uint32(1))
require.Contains(t, groupByID, uint32(2))
assert.Equal(t, "Src", groupByID[1].Name)
assert.Equal(t, "Dst", groupByID[2].Name)
assert.Len(t, groupByID[1].PeerIndexes, 1)
assert.Len(t, groupByID[2].PeerIndexes, 2)
}
func TestEncodeNetworkMapEnvelope_PolicyExpansion(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Policies, 1)
pc := full.Policies[0]
assert.EqualValues(t, 10, pc.Id)
assert.Equal(t, proto.RuleAction_ACCEPT, pc.Action)
assert.Equal(t, proto.RuleProtocol_TCP, pc.Protocol)
assert.True(t, pc.Bidirectional)
assert.Equal(t, []uint32{22, 80}, pc.Ports)
require.Len(t, pc.PortRanges, 1)
assert.EqualValues(t, 8000, pc.PortRanges[0].Start)
assert.EqualValues(t, 8100, pc.PortRanges[0].End)
assert.Equal(t, []uint32{1}, pc.SourceGroupIds)
assert.Equal(t, []uint32{2}, pc.DestinationGroupIds)
}
func TestEncodeNetworkMapEnvelope_RouterIndexes(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.RouterPeerIndexes, 1)
idx := full.RouterPeerIndexes[0]
require.Less(t, int(idx), len(full.Peers))
assert.Equal(t, "peerc", full.Peers[idx].DnsLabel)
}
func TestEncodeNetworkMapEnvelope_AgentVersionDedup(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.AgentVersions, 3, "empty placeholder + 2 distinct versions")
assert.Equal(t, "", full.AgentVersions[0], "index 0 reserved for empty version")
assert.ElementsMatch(t, []string{"0.40.0", "0.25.0"}, full.AgentVersions[1:],
"two distinct versions, order depends on map iteration")
idxByLabel := map[string]uint32{}
for _, p := range full.Peers {
idxByLabel[p.DnsLabel] = p.AgentVersionIdx
}
assert.Equal(t, idxByLabel["peera"], idxByLabel["peerc"], "peers with the same agent version share an index")
assert.NotEqual(t, idxByLabel["peera"], idxByLabel["peerb"])
}
func TestEncodeNetworkMapEnvelope_DisabledPolicySkipped(t *testing.T) {
c := newTestComponents()
c.Policies[0].Enabled = false
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
assert.Empty(t, full.Policies)
}
func TestEncodeNetworkMapEnvelope_GroupZeroSeqIDSkipped(t *testing.T) {
c := newTestComponents()
c.Groups["group-src"].AccountSeqID = 0
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Groups, 1, "groups with AccountSeqID=0 are not yet persisted and must be skipped")
assert.EqualValues(t, 2, full.Groups[0].Id)
require.Len(t, full.Policies, 1)
pc := full.Policies[0]
assert.Empty(t, pc.SourceGroupIds, "rule references a group that was filtered out → no group id on wire")
assert.Equal(t, []uint32{2}, pc.DestinationGroupIds)
}
func TestEncodeNetworkMapEnvelope_TwoPeersSameMalformedKey(t *testing.T) {
// Both peers have nil WgPubKey after decode; canonicalize must still
// produce a stable order using DnsLabel as a tiebreaker, so 100 encodes
// canonicalize identically.
c := newTestComponents()
c.Peers["peer-a"].Key = "garbage-a-!!!"
c.Peers["peer-b"].Key = "garbage-b-!!!"
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
for i := 0; i < 100; i++ {
got := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
require.True(t, envelopesEquivalent(expected, got),
"encode #%d with two same-key peers must canonicalize equivalently", i)
}
}
func TestEncodeNetworkMapEnvelope_MalformedWgKey(t *testing.T) {
c := newTestComponents()
c.Peers["peer-a"].Key = "not-base64-!!!"
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Peers, 3)
var byLabel = map[string]*proto.PeerCompact{}
for _, p := range full.Peers {
byLabel[p.DnsLabel] = p
}
assert.Nil(t, byLabel["peera"].WgPubKey, "peer with malformed key encodes nil WgPubKey")
assert.Len(t, byLabel["peerb"].WgPubKey, 32, "other peers retain their key")
}
func TestEncodeNetworkMapEnvelope_IPv6OnlyPeer(t *testing.T) {
c := newTestComponents()
v6Only := &nbpeer.Peer{
ID: "peer-v6",
Key: testWgKeyA,
IPv6: netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9}),
DNSLabel: "peerv6",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
c.Peers["peer-v6"] = v6Only
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
var found *proto.PeerCompact
for _, p := range full.Peers {
if p.DnsLabel == "peerv6" {
found = p
}
}
require.NotNil(t, found, "ipv6-only peer must be present")
assert.Empty(t, found.Ip, "no IPv4 address → empty Ip")
assert.Len(t, found.Ipv6, 16)
}
func TestEncodeNetworkMapEnvelope_PeerWithoutIP(t *testing.T) {
c := newTestComponents()
c.Peers["peer-noip"] = &nbpeer.Peer{
ID: "peer-noip",
Key: testWgKeyA,
DNSLabel: "peernoip",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
var found *proto.PeerCompact
for _, p := range full.Peers {
if p.DnsLabel == "peernoip" {
found = p
}
}
require.NotNil(t, found)
assert.Empty(t, found.Ip)
assert.Empty(t, found.Ipv6)
}
func TestEncodeNetworkMapEnvelope_EmptyInput(t *testing.T) {
c := &types.NetworkMapComponents{
Network: &types.Network{Identifier: "x", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}},
}
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
full := env.GetFull()
require.NotNil(t, full)
assert.Empty(t, full.Peers)
assert.Empty(t, full.Groups)
assert.Empty(t, full.Policies)
assert.Empty(t, full.RouterPeerIndexes)
require.NotNil(t, full.AccountSettings, "AccountSettingsCompact must always be emitted (client dereferences it unconditionally)")
}
func TestEncodeNetworkMapEnvelope_PeerLoginExpirationFields(t *testing.T) {
c := newTestComponents()
now := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC)
c.Peers["peer-a"].UserID = "user-1"
c.Peers["peer-a"].LoginExpirationEnabled = true
c.Peers["peer-a"].LastLogin = &now
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
var pa *proto.PeerCompact
for _, p := range full.Peers {
if p.DnsLabel == "peera" {
pa = p
}
}
require.NotNil(t, pa)
assert.True(t, pa.AddedWithSsoLogin)
assert.True(t, pa.LoginExpirationEnabled)
assert.Equal(t, now.UnixNano(), pa.LastLoginUnixNano)
// peer-b has no UserID and no LastLogin → all fields zero-value.
var pb *proto.PeerCompact
for _, p := range full.Peers {
if p.DnsLabel == "peerb" {
pb = p
}
}
require.NotNil(t, pb)
assert.False(t, pb.AddedWithSsoLogin)
assert.False(t, pb.LoginExpirationEnabled)
assert.Zero(t, pb.LastLoginUnixNano)
}
func TestEncodeNetworkMapEnvelope_RoutesRoundTrip(t *testing.T) {
c := newTestComponents()
c.Routes = []*nbroute.Route{
{
ID: "route-peer",
AccountSeqID: 100,
NetID: "net-A",
Description: "via peer-c",
Network: netip.MustParsePrefix("10.0.0.0/16"),
Peer: "peer-c", // peer ID, not WG key
Groups: []string{"group-src"},
AccessControlGroups: []string{"group-dst"},
Enabled: true,
},
{
ID: "route-peergroup",
AccountSeqID: 101,
NetID: "net-B",
Network: netip.MustParsePrefix("10.1.0.0/16"),
PeerGroups: []string{"group-src", "group-dst"},
Enabled: true,
},
{
ID: "route-no-seq",
AccountSeqID: 0, // unset — should still ship (no group seq filter on routes)
Network: netip.MustParsePrefix("10.2.0.0/16"),
Enabled: true,
},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Routes, 3)
byNetID := map[string]*proto.RouteRaw{}
for _, r := range full.Routes {
byNetID[r.NetId] = r
}
r1 := byNetID["net-A"]
require.NotNil(t, r1)
assert.True(t, r1.PeerIndexSet, "route with peer must set peer_index_set")
require.Less(t, int(r1.PeerIndex), len(full.Peers))
assert.Equal(t, "peerc", full.Peers[r1.PeerIndex].DnsLabel)
assert.Equal(t, []uint32{1}, r1.GroupIds, "group-src has AccountSeqID 1")
assert.Equal(t, []uint32{2}, r1.AccessControlGroupIds, "group-dst has AccountSeqID 2")
assert.Empty(t, r1.PeerGroupIds)
r2 := byNetID["net-B"]
require.NotNil(t, r2)
assert.False(t, r2.PeerIndexSet, "route with peer_groups must NOT set peer_index_set")
assert.ElementsMatch(t, []uint32{1, 2}, r2.PeerGroupIds)
}
func TestEncodeNetworkMapEnvelope_RouteWithMissingPeerLeavesIndexUnset(t *testing.T) {
c := newTestComponents()
c.Routes = []*nbroute.Route{{
ID: "route-x",
AccountSeqID: 100,
Peer: "peer-not-in-components",
Network: netip.MustParsePrefix("10.0.0.0/16"),
Enabled: true,
}}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Routes, 1)
assert.False(t, full.Routes[0].PeerIndexSet,
"missing peer reference must not pretend to point at peer index 0")
}
func TestEncodeNetworkMapEnvelope_ResourceOnlyPolicyShippedAndIndexed(t *testing.T) {
c := newTestComponents()
// Policy that exists ONLY in ResourcePoliciesMap, not in c.Policies. This
// is the I1 case — without unionPolicies the encoder would silently
// drop it from the wire.
resourceOnlyPolicy := &types.Policy{
ID: "pol-resource", AccountSeqID: 99, Enabled: true,
Rules: []*types.PolicyRule{{
ID: "rule-r", Enabled: true, Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolTCP,
Sources: []string{"group-src"},
Destinations: []string{"group-dst"},
}},
}
c.ResourcePoliciesMap = map[string][]*types.Policy{
"resource-x": {c.Policies[0], resourceOnlyPolicy}, // shared + resource-only
}
// Resource must appear in components.NetworkResources with a seq id —
// encoder uses that to translate the xid map key to uint32.
c.NetworkResources = []*resourceTypes.NetworkResource{
{ID: "resource-x", AccountSeqID: 77, Name: "res-x", Enabled: true},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Policies, 2, "encoded policies must include both peer-traffic and resource-only")
policyByID := map[uint32]*proto.PolicyCompact{}
policyIdxByID := map[uint32]uint32{}
for i, p := range full.Policies {
policyByID[p.Id] = p
policyIdxByID[p.Id] = uint32(i)
}
require.Contains(t, policyByID, uint32(10), "original peer-traffic policy id 10")
require.Contains(t, policyByID, uint32(99), "resource-only policy id 99")
require.Contains(t, full.ResourcePoliciesMap, uint32(77))
idxs := full.ResourcePoliciesMap[77].Indexes
require.Len(t, idxs, 2)
assert.ElementsMatch(t, []uint32{policyIdxByID[10], policyIdxByID[99]}, idxs,
"resource policies map must reference both wire policy indexes")
}
func TestEncodeNetworkMapEnvelope_NameServerGroups(t *testing.T) {
c := newTestComponents()
c.NameServerGroups = []*nbdns.NameServerGroup{{
ID: "nsg-1", AccountSeqID: 50, Name: "Main", Description: "primary",
NameServers: []nbdns.NameServer{{
IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53,
}},
Groups: []string{"group-src", "group-not-persisted"},
Primary: true, Enabled: true,
Domains: []string{"corp.example"},
}}
c.Groups["group-not-persisted"] = &types.Group{ID: "group-not-persisted", AccountSeqID: 0, Peers: []string{}}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.NameserverGroups, 1)
nsg := full.NameserverGroups[0]
assert.EqualValues(t, 50, nsg.Id)
assert.Equal(t, "Main", nsg.Name)
assert.True(t, nsg.Primary)
require.Len(t, nsg.Nameservers, 1)
assert.Equal(t, "8.8.8.8", nsg.Nameservers[0].IP)
assert.Equal(t, []uint32{1}, nsg.GroupIds, "group-not-persisted is filtered out (AccountSeqID=0)")
}
func TestEncodeNetworkMapEnvelope_PostureFailedPeers(t *testing.T) {
c := newTestComponents()
c.PostureCheckXIDToSeq = map[string]uint32{"check-1": 33}
c.PostureFailedPeers = map[string]map[string]struct{}{
"check-1": {
"peer-a": {},
"peer-b": {},
"peer-not-in-account": {},
},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Contains(t, full.PostureFailedPeers, uint32(33))
idxs := full.PostureFailedPeers[33].PeerIndexes
assert.Len(t, idxs, 2, "missing peer is silently dropped (filterPostureFailedPeers guarantees presence in real data)")
}
func TestEncodeNetworkMapEnvelope_RoutersMap(t *testing.T) {
c := newTestComponents()
c.NetworkXIDToSeq = map[string]uint32{"net-1": 5}
c.RoutersMap = map[string]map[string]*routerTypes.NetworkRouter{
"net-1": {
"peer-c": {
ID: "router-1", AccountSeqID: 200,
Peer: "peer-c", Masquerade: true, Metric: 10, Enabled: true,
},
},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Contains(t, full.RoutersMap, uint32(5))
entries := full.RoutersMap[5].Entries
require.Len(t, entries, 1)
e := entries[0]
assert.EqualValues(t, 200, e.Id)
assert.True(t, e.PeerIndexSet)
require.Less(t, int(e.PeerIndex), len(full.Peers))
assert.Equal(t, "peerc", full.Peers[e.PeerIndex].DnsLabel)
assert.True(t, e.Masquerade)
assert.EqualValues(t, 10, e.Metric)
assert.True(t, e.Enabled)
}
func TestEncodeNetworkMapEnvelope_RouterPeerNotInComponentsPeers(t *testing.T) {
// Router peer in c.RouterPeers but NOT in c.Peers (validation may have
// filtered it). indexRouterPeers runs before encodeRoutersMap, so the
// peer_index reference must still resolve.
c := newTestComponents()
delete(c.Peers, "peer-c")
routerPeer := &nbpeer.Peer{
ID: "peer-c", Key: testWgKeyC, IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}),
DNSLabel: "peerc", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
c.RouterPeers = map[string]*nbpeer.Peer{"peer-c": routerPeer}
c.NetworkXIDToSeq = map[string]uint32{"net-1": 5}
c.RoutersMap = map[string]map[string]*routerTypes.NetworkRouter{
"net-1": {"peer-c": {ID: "r-1", AccountSeqID: 1, Peer: "peer-c", Enabled: true}},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Contains(t, full.RoutersMap, uint32(5))
require.Len(t, full.RoutersMap[5].Entries, 1)
e := full.RoutersMap[5].Entries[0]
assert.True(t, e.PeerIndexSet, "router peer must be indexed even when not in c.Peers")
}
func TestEncodeNetworkMapEnvelope_DNSSettingsFiltersUnpersistedGroups(t *testing.T) {
c := newTestComponents()
c.DNSSettings = &types.DNSSettings{
DisabledManagementGroups: []string{"group-src", "group-missing", "group-no-seq"},
}
c.Groups["group-no-seq"] = &types.Group{ID: "group-no-seq", AccountSeqID: 0}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.NotNil(t, full.DnsSettings)
assert.Equal(t, []uint32{1}, full.DnsSettings.DisabledManagementGroupIds,
"only group-src (AccountSeqID=1) survives — missing and unpersisted are dropped")
}
func TestEncodeNetworkMapEnvelope_GroupIDToUserIDs(t *testing.T) {
c := newTestComponents()
c.GroupIDToUserIDs = map[string][]string{
"group-src": {"user-1", "user-2"},
"group-no-seq": {"user-3"}, // group not persisted → drop
"group-missing": {"user-4"}, // group not in components → drop
}
c.Groups["group-no-seq"] = &types.Group{ID: "group-no-seq", AccountSeqID: 0}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.GroupIdToUserIds, 1, "only persisted+present groups survive")
require.Contains(t, full.GroupIdToUserIds, uint32(1))
assert.ElementsMatch(t, []string{"user-1", "user-2"}, full.GroupIdToUserIds[1].UserIds)
}
func TestToProxyPatch_EmptyInputReturnsNil(t *testing.T) {
assert.Nil(t, toProxyPatch(nil, "netbird.cloud", false, false))
assert.Nil(t, toProxyPatch(&types.NetworkMap{}, "netbird.cloud", false, false),
"empty NetworkMap (no peers, rules, routes etc) → nil patch so proto3 omits the field")
}
func TestToProxyPatch_PopulatesAllFields(t *testing.T) {
nm := &types.NetworkMap{
Peers: []*nbpeer.Peer{{
ID: "ext-peer", Key: testWgKeyA, IP: netip.AddrFrom4([4]byte{100, 64, 0, 9}),
DNSLabel: "extpeer", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}},
FirewallRules: []*types.FirewallRule{{
PeerIP: "100.64.0.9", Action: "accept", Direction: 0, Protocol: "tcp",
}},
}
patch := toProxyPatch(nm, "netbird.cloud", false, false)
require.NotNil(t, patch)
assert.Len(t, patch.Peers, 1)
assert.Len(t, patch.FirewallRules, 1)
}
// TestEncodeNetworkMapEnvelope_ProxyPatchPropagated covers the ProxyPatch
// pass-through in both encoder branches (normal path + nil-Components
// graceful-degrade). Guards against a regression that drops `ProxyPatch:`
// from one of the envelope struct literals.
func TestEncodeNetworkMapEnvelope_ProxyPatchPropagated(t *testing.T) {
patch := &proto.ProxyPatch{
ForwardingRules: []*proto.ForwardingRule{{
Protocol: proto.RuleProtocol_TCP,
DestinationPort: &proto.PortInfo{PortSelection: &proto.PortInfo_Port{Port: 80}},
TranslatedAddress: net.IPv4(10, 0, 0, 1).To4(),
TranslatedPort: &proto.PortInfo{PortSelection: &proto.PortInfo_Port{Port: 8080}},
}},
}
t.Run("normal_path", func(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: c,
ProxyPatch: patch,
}).GetFull()
require.NotNil(t, full.ProxyPatch, "ProxyPatch must propagate through the normal encode path")
assert.Len(t, full.ProxyPatch.ForwardingRules, 1)
})
t.Run("nil_components_graceful_degrade", func(t *testing.T) {
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: nil,
ProxyPatch: patch,
}).GetFull()
require.NotNil(t, full.ProxyPatch, "ProxyPatch must propagate through the nil-Components branch too")
assert.Len(t, full.ProxyPatch.ForwardingRules, 1)
})
}
func TestEncodeNetworkMapEnvelope_NilComponentsGracefulDegrade(t *testing.T) {
// nil Components → minimal envelope, no crash. Matches the legacy
// behaviour for missing/unvalidated peers.
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: nil,
DNSDomain: "netbird.cloud",
})
require.NotNil(t, env)
full := env.GetFull()
require.NotNil(t, full)
require.NotNil(t, full.AccountSettings, "AccountSettings must always be non-nil")
assert.Equal(t, "netbird.cloud", full.DnsDomain)
assert.Empty(t, full.Peers)
assert.Empty(t, full.Policies)
}
func TestEncodeNetworkMapEnvelope_AccountSettingsAlwaysEmitted(t *testing.T) {
c := &types.NetworkMapComponents{
Network: &types.Network{Identifier: "x", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}},
// AccountSettings deliberately nil
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.NotNil(t, full.AccountSettings, "client dereferences AccountSettings unconditionally during Calculate(); a nil here would crash the receiver")
assert.False(t, full.AccountSettings.PeerLoginExpirationEnabled)
assert.Zero(t, full.AccountSettings.PeerLoginExpirationNs)
}

View File

@@ -0,0 +1,192 @@
package grpc
import (
"context"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/client/ssh/auth"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/networkmap"
"github.com/netbirdio/netbird/shared/management/proto"
)
// ToComponentSyncResponse builds a SyncResponse carrying the compact
// NetworkMapEnvelope for capability-aware peers. The legacy proto.NetworkMap
// field is intentionally left empty — capable peers ignore it and the
// envelope alone is the authoritative wire shape.
//
// PeerConfig is computed once server-side using the receiving peer's own
// account-level network metadata. EnableSSH inside PeerConfig is left at
// peer.SSHEnabled (the peer's local setting); account-policy-driven SSH is
// computed by the client from the envelope's GroupIDToUserIDs / AllowedUserIDs
// inside Calculate(), so the SshConfig.SshEnabled bit may flip true on the
// client even though the server-side PeerConfig reports false.
func ToComponentSyncResponse(
ctx context.Context,
config *nbconfig.Config,
httpConfig *nbconfig.HttpServerConfig,
deviceFlowConfig *nbconfig.DeviceAuthorizationFlow,
peer *nbpeer.Peer,
turnCredentials *Token,
relayCredentials *Token,
components *types.NetworkMapComponents,
proxyPatch *types.NetworkMap,
dnsName string,
checks []*posture.Checks,
settings *types.Settings,
extraSettings *types.ExtraSettings,
peerGroups []string,
dnsFwdPort int64,
) *proto.SyncResponse {
network := networkOrZero(components)
enableSSH := computeSSHEnabledForPeer(components, peer)
peerConfig := toPeerConfig(peer, network, dnsName, settings, httpConfig, deviceFlowConfig, enableSSH)
includeIPv6 := peer.SupportsIPv6() && peer.IPv6.IsValid()
useSourcePrefixes := peer.SupportsSourcePrefixes()
userIDClaim := auth.DefaultUserIDClaim
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
userIDClaim = httpConfig.AuthUserIDClaim
}
envelope := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: components,
PeerConfig: peerConfig,
DNSDomain: dnsName,
DNSForwarderPort: dnsFwdPort,
UserIDClaim: userIDClaim,
ProxyPatch: toProxyPatch(proxyPatch, dnsName, includeIPv6, useSourcePrefixes),
})
resp := &proto.SyncResponse{
PeerConfig: peerConfig,
NetworkMapEnvelope: envelope,
Checks: toProtocolChecks(ctx, checks),
}
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
resp.NetbirdConfig = integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
return resp
}
// networkOrZero returns components.Network or a zero Network — toPeerConfig
// dereferences network.Net which would panic on nil.
func networkOrZero(c *types.NetworkMapComponents) *types.Network {
if c == nil || c.Network == nil {
return &types.Network{}
}
return c.Network
}
// toProxyPatch converts a proxy-injected *types.NetworkMap into the wire
// patch the components envelope ships alongside. Returns nil when there are
// no fragments to merge — proto3 omits a nil message field, so the receiver
// sees no patch and skips the merge step entirely.
//
// We reuse the legacy proto-conversion helpers (toProtocolRoutes,
// toProtocolFirewallRules, toProtocolRoutesFirewallRules,
// appendRemotePeerConfig, ForwardingRule.ToProto) because the proxy
// delivers fragments pre-expanded — there's no raw component shape to
// derive them from. Components purity isn't violated: proxy data isn't
// policy-graph-derived, it's externally injected post-Calculate, so the
// client merges it on top of its locally-computed NetworkMap.
func toProxyPatch(nm *types.NetworkMap, dnsName string, includeIPv6, useSourcePrefixes bool) *proto.ProxyPatch {
if nm == nil {
return nil
}
if len(nm.Peers) == 0 && len(nm.OfflinePeers) == 0 && len(nm.FirewallRules) == 0 &&
len(nm.Routes) == 0 && len(nm.RoutesFirewallRules) == 0 && len(nm.ForwardingRules) == 0 {
return nil
}
patch := &proto.ProxyPatch{
Peers: networkmap.AppendRemotePeerConfig(nil, nm.Peers, dnsName, includeIPv6),
OfflinePeers: networkmap.AppendRemotePeerConfig(nil, nm.OfflinePeers, dnsName, includeIPv6),
FirewallRules: networkmap.ToProtocolFirewallRules(nm.FirewallRules, includeIPv6, useSourcePrefixes),
Routes: networkmap.ToProtocolRoutes(nm.Routes),
RouteFirewallRules: networkmap.ToProtocolRoutesFirewallRules(nm.RoutesFirewallRules),
}
if len(nm.ForwardingRules) > 0 {
patch.ForwardingRules = make([]*proto.ForwardingRule, 0, len(nm.ForwardingRules))
for _, r := range nm.ForwardingRules {
patch.ForwardingRules = append(patch.ForwardingRules, r.ToProto())
}
}
return patch
}
// computeSSHEnabledForPeer mirrors the SSH-server-activation bit that
// Calculate() folds into NetworkMap.EnableSSH. Components-format peers
// receive a freshly-computed PeerConfig.SshConfig.SshEnabled at sync time;
// without this helper the field would be incorrectly false for any peer
// that's the destination of an SSH-enabling policy without having
// peer.SSHEnabled set locally.
//
// Mirrors the two activation paths Calculate() uses:
// 1. Explicit: rule.Protocol == NetbirdSSH and peer is in the rule's
// destinations.
// 2. Legacy implicit: rule covers TCP/22 or TCP/22022 (or ALL), peer is in
// destinations, AND the peer has SSHEnabled set locally — this is the
// "allow-all/TCP-22 implies SSH activation for SSH-capable peers" path.
//
// The full SSH AuthorizedUsers map is still produced by the client when it
// runs Calculate() over the envelope.
func computeSSHEnabledForPeer(c *types.NetworkMapComponents, peer *nbpeer.Peer) bool {
if c == nil || peer == nil {
return false
}
// Mirror Calculate's `getAllPeersFromGroups` invariant: target peer must
// exist in c.Peers, otherwise no rule applies to it.
if _, ok := c.Peers[peer.ID]; !ok {
return false
}
for _, policy := range c.Policies {
if policy == nil || !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if ruleEnablesSSHForPeer(c, rule, peer) {
return true
}
}
}
return false
}
// ruleEnablesSSHForPeer returns true when rule is active, targets peer, and
// either explicitly authorises SSH or covers the legacy TCP/22 path while the
// peer itself has SSH enabled locally.
func ruleEnablesSSHForPeer(c *types.NetworkMapComponents, rule *types.PolicyRule, peer *nbpeer.Peer) bool {
if rule == nil || !rule.Enabled {
return false
}
if !peerInDestinations(c, rule, peer.ID) {
return false
}
if rule.Protocol == types.PolicyRuleProtocolNetbirdSSH {
return true
}
return peer.SSHEnabled && types.PolicyRuleImpliesLegacySSH(rule)
}
// peerInDestinations reports whether peerID is in any of rule.Destinations'
// groups (or matches DestinationResource if it's a peer-typed resource —
// for non-peer types Calculate falls through to group lookup, so we mirror
// that exactly to avoid silent divergence).
func peerInDestinations(c *types.NetworkMapComponents, rule *types.PolicyRule, peerID string) bool {
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
return rule.DestinationResource.ID == peerID
}
for _, groupID := range rule.Destinations {
if c.IsPeerInGroup(peerID, groupID) {
return true
}
}
return false
}

View File

@@ -0,0 +1,184 @@
package grpc
import (
"testing"
"github.com/stretchr/testify/assert"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
)
// TestComputeSSHEnabledForPeer covers both Calculate-mirroring branches:
// explicit NetbirdSSH protocol, and the legacy implicit case where a
// TCP/22 (or 22022 / ALL / port-range-covering-22) rule activates SSH when
// the destination peer has SSHEnabled=true locally.
func TestComputeSSHEnabledForPeer(t *testing.T) {
const targetPeerID = "target"
const targetGroupID = "g_dst"
mkComponents := func(rule *types.PolicyRule, sshEnabled bool) (*types.NetworkMapComponents, *nbpeer.Peer) {
peer := &nbpeer.Peer{ID: targetPeerID, SSHEnabled: sshEnabled}
group := &types.Group{ID: targetGroupID, Name: "dst", Peers: []string{targetPeerID}}
return &types.NetworkMapComponents{
Peers: map[string]*nbpeer.Peer{targetPeerID: peer},
Groups: map[string]*types.Group{targetGroupID: group},
Policies: []*types.Policy{{
ID: "p",
Enabled: true,
Rules: []*types.PolicyRule{rule},
}},
}, peer
}
cases := []struct {
name string
peerSSH bool
rule types.PolicyRule
wantEnabled bool
}{
{
name: "explicit-netbird-ssh-activates-regardless-of-peer-ssh",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-tcp-22-with-peer-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-tcp-22-without-peer-ssh-disabled",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
Destinations: []string{targetGroupID},
},
wantEnabled: false,
},
{
name: "implicit-tcp-22022-with-peer-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22022"},
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-all-protocol-with-peer-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolALL,
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-port-range-covers-22",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true,
Protocol: types.PolicyRuleProtocolTCP,
PortRanges: []types.RulePortRange{{Start: 20, End: 30}},
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "tcp-80-no-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"80"},
Destinations: []string{targetGroupID},
},
wantEnabled: false,
},
{
name: "disabled-rule-skipped",
peerSSH: true,
rule: types.PolicyRule{
Enabled: false, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{targetGroupID},
},
wantEnabled: false,
},
{
name: "peer-not-in-destinations",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{"g_other"}, // target not in this group
},
wantEnabled: false,
},
{
name: "peer-typed-destination-resource-matches",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true,
Protocol: types.PolicyRuleProtocolNetbirdSSH,
DestinationResource: types.Resource{ID: targetPeerID, Type: types.ResourceTypePeer},
},
wantEnabled: true,
},
{
name: "non-peer-destination-resource-falls-through-to-groups",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true,
Protocol: types.PolicyRuleProtocolNetbirdSSH,
DestinationResource: types.Resource{ID: targetPeerID, Type: "host"}, // wrong type
Destinations: []string{targetGroupID}, // saved by group fallback
},
wantEnabled: true,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
c, peer := mkComponents(&tc.rule, tc.peerSSH)
got := computeSSHEnabledForPeer(c, peer)
assert.Equal(t, tc.wantEnabled, got)
})
}
}
// TestComputeSSHEnabledForPeer_TargetMissingFromComponents covers the
// belt-and-suspenders presence guard mirroring Calculate's
// getAllPeersFromGroups invariant.
func TestComputeSSHEnabledForPeer_TargetMissingFromComponents(t *testing.T) {
peer := &nbpeer.Peer{ID: "missing", SSHEnabled: true}
c := &types.NetworkMapComponents{
Peers: map[string]*nbpeer.Peer{}, // target peer NOT present
Groups: map[string]*types.Group{
"g": {ID: "g", Peers: []string{"missing"}},
},
Policies: []*types.Policy{{
ID: "p", Enabled: true,
Rules: []*types.PolicyRule{{
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{"g"},
}},
}},
}
assert.False(t, computeSSHEnabledForPeer(c, peer),
"missing target peer must short-circuit to false, not consult policies")
}
// TestComputeSSHEnabledForPeer_NilInputs guards the cheap nil-checks at
// function entry — Calculate doesn't accept nil either, but the helper is
// exported indirectly via ToComponentSyncResponse and may receive nil
// components on graceful-degrade paths.
func TestComputeSSHEnabledForPeer_NilInputs(t *testing.T) {
assert.False(t, computeSSHEnabledForPeer(nil, &nbpeer.Peer{ID: "x"}))
assert.False(t, computeSSHEnabledForPeer(&types.NetworkMapComponents{}, nil))
}

View File

@@ -6,24 +6,22 @@ import (
"net/netip" "net/netip"
"net/url" "net/url"
"strings" "strings"
"time"
log "github.com/sirupsen/logrus" "google.golang.org/protobuf/types/known/timestamppb"
goproto "google.golang.org/protobuf/proto"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/client/ssh/auth" "github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config" nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
nbroute "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/networkmap"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil" "github.com/netbirdio/netbird/shared/netiputil"
"github.com/netbirdio/netbird/shared/sshauth"
) )
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig { func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
@@ -138,8 +136,8 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH), PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
NetworkMap: &proto.NetworkMap{ NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(), Serial: networkMap.Network.CurrentSerial(),
Routes: toProtocolRoutes(networkMap.Routes), Routes: networkmap.ToProtocolRoutes(networkMap.Routes),
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort), DNSConfig: networkmap.ToProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH), PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
}, },
Checks: toProtocolChecks(ctx, checks), Checks: toProtocolChecks(ctx, checks),
@@ -152,19 +150,19 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
response.NetworkMap.PeerConfig = response.PeerConfig response.NetworkMap.PeerConfig = response.PeerConfig
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6) remotePeers = networkmap.AppendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
response.RemotePeers = remotePeers response.RemotePeers = remotePeers
response.NetworkMap.RemotePeers = remotePeers response.NetworkMap.RemotePeers = remotePeers
response.RemotePeersIsEmpty = len(remotePeers) == 0 response.RemotePeersIsEmpty = len(remotePeers) == 0
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6) response.NetworkMap.OfflinePeers = networkmap.AppendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes) firewallRules := networkmap.ToProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes)
response.NetworkMap.FirewallRules = firewallRules response.NetworkMap.FirewallRules = firewallRules
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0 response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules) routesFirewallRules := networkmap.ToProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
response.NetworkMap.RoutesFirewallRules = routesFirewallRules response.NetworkMap.RoutesFirewallRules = routesFirewallRules
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0 response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
@@ -177,7 +175,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
} }
if networkMap.AuthorizedUsers != nil { if networkMap.AuthorizedUsers != nil {
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers) hashedUsers, machineUsers := networkmap.BuildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
userIDClaim := auth.DefaultUserIDClaim userIDClaim := auth.DefaultUserIDClaim
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" { if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
userIDClaim = httpConfig.AuthUserIDClaim userIDClaim = httpConfig.AuthUserIDClaim
@@ -185,79 +183,36 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim} response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
} }
// settings == nil → field stays nil → "no info in this snapshot", client
// preserves the deadline it already had. settings non-nil → emit either a
// valid deadline or the explicit-zero "disabled" sentinel via
// encodeSessionExpiresAt.
if settings != nil {
response.SessionExpiresAt = encodeSessionExpiresAt(
peer.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration),
)
}
return response return response
} }
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) { // encodeSessionExpiresAt encodes a server-side deadline into the 3-state wire
userIDToIndex := make(map[string]uint32) // representation used on LoginResponse, SyncResponse and
var hashedUsers [][]byte // ExtendAuthSessionResponse. See the proto comments on those messages.
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers)) //
// - deadline.IsZero() → returns &Timestamp{} (seconds=0, nanos=0): the
for machineUser, users := range authorizedUsers { // "expiry disabled or peer is not SSO-tracked" sentinel; the client clears
indexes := make([]uint32, 0, len(users)) // its anchor.
for userID := range users { // - deadline non-zero → returns timestamppb.New(deadline): the new absolute
idx, exists := userIDToIndex[userID] // UTC deadline.
if !exists { //
hash, err := sshauth.HashUserID(userID) // Returning nil ("no info, preserve client's anchor") is the caller's job —
if err != nil { // only meaningful on Sync builds where settings were not resolved.
log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err) func encodeSessionExpiresAt(deadline time.Time) *timestamppb.Timestamp {
continue if deadline.IsZero() {
} return &timestamppb.Timestamp{}
idx = uint32(len(hashedUsers))
userIDToIndex[userID] = idx
hashedUsers = append(hashedUsers, hash[:])
}
indexes = append(indexes, idx)
}
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
} }
return timestamppb.New(deadline)
return hashedUsers, machineUsers
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
allowedIPs := []string{rPeer.IP.String() + "/32"}
if includeIPv6 && rPeer.IPv6.IsValid() {
allowedIPs = append(allowedIPs, rPeer.IPv6.String()+"/128")
}
dst = append(dst, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: allowedIPs,
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: rPeer.FQDN(dnsName),
AgentVersion: rPeer.Meta.WtVersion,
})
}
return dst
}
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
func toProtocolDNSConfig(update nbdns.Config, cache *cache.DNSConfigCache, forwardPort int64) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{
ServiceEnable: update.ServiceEnable,
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
ForwarderPort: forwardPort,
}
for _, zone := range update.CustomZones {
protoZone := convertToProtoCustomZone(zone)
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
}
for _, nsGroup := range update.NameServerGroups {
cacheKey := nsGroup.ID
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
} else {
protoGroup := convertToProtoNameServerGroup(nsGroup)
cache.SetNameServerGroup(cacheKey, protoGroup)
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
}
return protoUpdate
} }
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol { func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
@@ -277,204 +232,6 @@ func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
} }
} }
func toProtocolRoutes(routes []*nbroute.Route) []*proto.Route {
protoRoutes := make([]*proto.Route, 0, len(routes))
for _, r := range routes {
protoRoutes = append(protoRoutes, toProtocolRoute(r))
}
return protoRoutes
}
func toProtocolRoute(route *nbroute.Route) *proto.Route {
return &proto.Route{
ID: string(route.ID),
NetID: string(route.NetID),
Network: route.Network.String(),
Domains: route.Domains.ToPunycodeList(),
NetworkType: int64(route.NetworkType),
Peer: route.Peer,
Metric: int64(route.Metric),
Masquerade: route.Masquerade,
KeepRoute: route.KeepRoute,
SkipAutoApply: route.SkipAutoApply,
}
}
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
// When useSourcePrefixes is true, the compact SourcePrefixes field is populated
// alongside the deprecated PeerIP for forward compatibility.
// Wildcard rules ("0.0.0.0") are expanded into separate v4 and v6 SourcePrefixes
// when includeIPv6 is true.
func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSourcePrefixes bool) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, 0, len(rules))
for i := range rules {
rule := rules[i]
fwRule := &proto.FirewallRule{
PolicyID: []byte(rule.PolicyID),
PeerIP: rule.PeerIP, //nolint:staticcheck // populated for backward compatibility
Direction: getProtoDirection(rule.Direction),
Action: getProtoAction(rule.Action),
Protocol: getProtoProtocol(rule.Protocol),
Port: rule.Port,
}
if useSourcePrefixes && rule.PeerIP != "" {
result = append(result, populateSourcePrefixes(fwRule, rule, includeIPv6)...)
}
if shouldUsePortRange(fwRule) {
fwRule.PortInfo = rule.PortRange.ToProto()
}
result = append(result, fwRule)
}
return result
}
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified).
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {
addr, err := netip.ParseAddr(rule.PeerIP)
if err != nil {
return nil
}
if !addr.IsUnspecified() {
fwRule.SourcePrefixes = [][]byte{netiputil.EncodeAddr(addr.Unmap())}
return nil
}
// IPv4Unspecified/0 is always valid, error is impossible.
v4Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv4Unspecified(), 0))
fwRule.SourcePrefixes = [][]byte{v4Wildcard}
if !includeIPv6 {
return nil
}
v6Rule := goproto.Clone(fwRule).(*proto.FirewallRule)
v6Rule.PeerIP = "::" //nolint:staticcheck // populated for backward compatibility
// IPv6Unspecified/0 is always valid, error is impossible.
v6Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv6Unspecified(), 0))
v6Rule.SourcePrefixes = [][]byte{v6Wildcard}
if shouldUsePortRange(v6Rule) {
v6Rule.PortInfo = rule.PortRange.ToProto()
}
return []*proto.FirewallRule{v6Rule}
}
// getProtoDirection converts the direction to proto.RuleDirection.
func getProtoDirection(direction int) proto.RuleDirection {
if direction == types.FirewallRuleDirectionOUT {
return proto.RuleDirection_OUT
}
return proto.RuleDirection_IN
}
func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
result := make([]*proto.RouteFirewallRule, len(rules))
for i := range rules {
rule := rules[i]
result[i] = &proto.RouteFirewallRule{
SourceRanges: rule.SourceRanges,
Action: getProtoAction(rule.Action),
Destination: rule.Destination,
Protocol: getProtoProtocol(rule.Protocol),
PortInfo: getProtoPortInfo(rule),
IsDynamic: rule.IsDynamic,
Domains: rule.Domains.ToPunycodeList(),
PolicyID: []byte(rule.PolicyID),
RouteID: string(rule.RouteID),
}
}
return result
}
// getProtoAction converts the action to proto.RuleAction.
func getProtoAction(action string) proto.RuleAction {
if action == string(types.PolicyTrafficActionDrop) {
return proto.RuleAction_DROP
}
return proto.RuleAction_ACCEPT
}
// getProtoProtocol converts the protocol to proto.RuleProtocol.
func getProtoProtocol(protocol string) proto.RuleProtocol {
switch types.PolicyRuleProtocolType(protocol) {
case types.PolicyRuleProtocolALL:
return proto.RuleProtocol_ALL
case types.PolicyRuleProtocolTCP:
return proto.RuleProtocol_TCP
case types.PolicyRuleProtocolUDP:
return proto.RuleProtocol_UDP
case types.PolicyRuleProtocolICMP:
return proto.RuleProtocol_ICMP
default:
return proto.RuleProtocol_UNKNOWN
}
}
// getProtoPortInfo converts the port info to proto.PortInfo.
func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
var portInfo proto.PortInfo
if rule.Port != 0 {
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
portInfo.PortSelection = &proto.PortInfo_Range_{
Range: &proto.PortInfo_Range{
Start: uint32(portRange.Start),
End: uint32(portRange.End),
},
}
}
return &portInfo
}
func shouldUsePortRange(rule *proto.FirewallRule) bool {
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
}
// Helper function to convert nbdns.CustomZone to proto.CustomZone
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
SearchDomainDisabled: zone.SearchDomainDisabled,
NonAuthoritative: zone.NonAuthoritative,
}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name,
Type: int64(record.Type),
Class: record.Class,
TTL: int64(record.TTL),
RData: record.RData,
})
}
return protoZone
}
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
}
for _, ns := range nsGroup.NameServers {
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
})
}
return protoGroup
}
// buildJWTConfig constructs JWT configuration for SSH servers from management server config // buildJWTConfig constructs JWT configuration for SSH servers from management server config
func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.JWTConfig { func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.JWTConfig {
if config == nil || config.AuthAudience == "" { if config == nil || config.AuthAudience == "" {

View File

@@ -5,6 +5,7 @@ import (
"net/netip" "net/netip"
"reflect" "reflect"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -12,6 +13,7 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config" nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/shared/management/networkmap"
) )
func TestToProtocolDNSConfigWithCache(t *testing.T) { func TestToProtocolDNSConfigWithCache(t *testing.T) {
@@ -61,13 +63,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
} }
// First run with config1 // First run with config1
result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort)) result1 := networkmap.ToProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
// Second run with config2 // Second run with config2
result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort)) result2 := networkmap.ToProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
// Third run with config1 again // Third run with config1 again
result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort)) result3 := networkmap.ToProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
// Verify that result1 and result3 are identical // Verify that result1 and result3 are identical
if !reflect.DeepEqual(result1, result3) { if !reflect.DeepEqual(result1, result3) {
@@ -99,7 +101,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort)) networkmap.ToProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
} }
}) })
@@ -107,7 +109,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
cache := &cache.DNSConfigCache{} cache := &cache.DNSConfigCache{}
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort)) networkmap.ToProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
} }
}) })
} }
@@ -200,3 +202,29 @@ func TestBuildJWTConfig_Audiences(t *testing.T) {
}) })
} }
} }
// TestEncodeSessionExpiresAt pins the wire encoding the client's
// applySessionDeadline depends on:
//
// - zero deadline → &Timestamp{} (seconds=0, nanos=0): the explicit
// "expiry disabled or peer is not SSO-tracked" sentinel.
// - non-zero → timestamppb.New(deadline): the absolute UTC deadline.
//
// The third state (nil pointer = "no info in this snapshot") is the caller's
// responsibility on the Sync path when settings could not be resolved; the
// helper itself never returns nil.
func TestEncodeSessionExpiresAt(t *testing.T) {
t.Run("zero deadline encodes as explicit-zero sentinel", func(t *testing.T) {
got := encodeSessionExpiresAt(time.Time{})
assert.NotNil(t, got, "must not return nil; nil means 'no info', not 'disabled'")
assert.Equal(t, int64(0), got.GetSeconds())
assert.Equal(t, int32(0), got.GetNanos())
})
t.Run("non-zero deadline round-trips", func(t *testing.T) {
deadline := time.Date(2030, 1, 2, 3, 4, 5, 0, time.UTC)
got := encodeSessionExpiresAt(deadline)
assert.NotNil(t, got)
assert.True(t, got.AsTime().Equal(deadline))
})
}

View File

@@ -351,6 +351,7 @@ func (s *ProxyServiceServer) registerProxyConnection(ctx context.Context, params
SupportsCustomPorts: c.SupportsCustomPorts, SupportsCustomPorts: c.SupportsCustomPorts,
RequireSubdomain: c.RequireSubdomain, RequireSubdomain: c.RequireSubdomain,
SupportsCrowdsec: c.SupportsCrowdsec, SupportsCrowdsec: c.SupportsCrowdsec,
Private: c.Private,
} }
} }
@@ -754,6 +755,11 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
InitialSyncComplete: update.InitialSyncComplete, InitialSyncComplete: update.InitialSyncComplete,
} }
} }
// Drop mappings the proxy lacks capability for (e.g. private without SupportsPrivateService).
connUpdate = filterMappingsForProxy(conn, connUpdate)
if connUpdate == nil || len(connUpdate.Mapping) == 0 {
return true
}
resp := s.perProxyMessage(connUpdate, conn.proxyID) resp := s.perProxyMessage(connUpdate, conn.proxyID)
if resp == nil { if resp == nil {
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID) log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
@@ -882,16 +888,20 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
} }
} }
// proxyAcceptsMapping returns whether the proxy should receive this mapping. // proxyAcceptsMapping returns whether the proxy can receive this mapping.
// Old proxies that never reported capabilities are skipped for non-TLS L4 // Private mappings require SupportsPrivateService; custom-port L4 mappings
// mappings with a custom listen port, since they don't understand the // require SupportsCustomPorts. Remove operations always pass so proxies can
// protocol. Proxies that report capabilities (even SupportsCustomPorts=false) // clean up.
// are new enough to handle the mapping. TLS uses SNI routing and works on
// any proxy. Delete operations are always sent so proxies can clean up.
func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool { func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool {
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED { if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
return true return true
} }
if mapping.GetPrivate() {
caps := conn.capabilities
if caps == nil || caps.SupportsPrivateService == nil || !*caps.SupportsPrivateService {
return false
}
}
if mapping.ListenPort == 0 || mapping.Mode == "tls" { if mapping.ListenPort == 0 || mapping.Mode == "tls" {
return true return true
} }
@@ -900,6 +910,29 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo
return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil
} }
// filterMappingsForProxy drops mappings the proxy cannot safely receive
// (e.g. private mappings to a proxy without SupportsPrivateService).
// Returns the input unchanged when no filtering is needed.
func filterMappingsForProxy(conn *proxyConnection, update *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse {
if update == nil || len(update.Mapping) == 0 {
return update
}
kept := make([]*proto.ProxyMapping, 0, len(update.Mapping))
for _, m := range update.Mapping {
if !proxyAcceptsMapping(conn, m) {
continue
}
kept = append(kept, m)
}
if len(kept) == len(update.Mapping) {
return update
}
return &proto.GetMappingUpdateResponse{
Mapping: kept,
InitialSyncComplete: update.InitialSyncComplete,
}
}
// perProxyMessage returns a copy of update with a fresh one-time token for // perProxyMessage returns a copy of update with a fresh one-time token for
// create/update operations. For delete operations the original mapping is // create/update operations. For delete operations the original mapping is
// used unchanged because proxies do not need to authenticate for removal. // used unchanged because proxies do not need to authenticate for removal.
@@ -961,7 +994,10 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
authenticated, userId, method := s.authenticateRequest(ctx, req, service) authenticated, userId, method := s.authenticateRequest(ctx, req, service)
token, err := s.generateSessionToken(ctx, authenticated, service, userId, method) // Non-OIDC schemes (PIN/Password/Header) authenticate against per-service
// secrets and have no user-level group context, so groups stay nil. Email
// is also empty — these schemes don't resolve a user record at sign time.
token, err := s.generateSessionToken(ctx, authenticated, service, userId, "", method, nil, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1050,7 +1086,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err
} }
} }
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) { func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId, userEmail string, method proxyauth.Method, groupIDs, groupNames []string) (string, error) {
if !authenticated || service.SessionPrivateKey == "" { if !authenticated || service.SessionPrivateKey == "" {
return "", nil return "", nil
} }
@@ -1058,8 +1094,11 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
token, err := sessionkey.SignToken( token, err := sessionkey.SignToken(
service.SessionPrivateKey, service.SessionPrivateKey,
userId, userId,
userEmail,
service.Domain, service.Domain,
method, method,
groupIDs,
groupNames,
proxyauth.DefaultSessionExpiry, proxyauth.DefaultSessionExpiry,
) )
if err != nil { if err != nil {
@@ -1070,6 +1109,26 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
return token, nil return token, nil
} }
// pairGroupIDsAndNames splits a slice of resolved *types.Group records
// into parallel id and name slices. ids[i] and names[i] always pair to
// the same group. nil entries (orphan ids the manager couldn't resolve)
// are skipped so the consumer can rely on positional pairing.
func pairGroupIDsAndNames(groups []*types.Group) (ids, names []string) {
if len(groups) == 0 {
return nil, nil
}
ids = make([]string, 0, len(groups))
names = make([]string, 0, len(groups))
for _, g := range groups {
if g == nil {
continue
}
ids = append(ids, g.ID)
names = append(names, g.Name)
}
return ids, names
}
// SendStatusUpdate handles status updates from proxy clients. // SendStatusUpdate handles status updates from proxy clients.
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) { func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil { if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
@@ -1334,7 +1393,9 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
return verifier, redirectURL, nil return verifier, redirectURL, nil
} }
// GenerateSessionToken creates a signed session JWT for the given domain and user. // GenerateSessionToken creates a signed session JWT for the given domain and
// user. The user's group memberships are embedded in the token so policy-aware
// middlewares on the proxy can authorise without an extra management round-trip.
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) { func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
service, err := s.getServiceByDomain(ctx, domain) service, err := s.getServiceByDomain(ctx, domain)
if err != nil { if err != nil {
@@ -1345,11 +1406,29 @@ func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, u
return "", fmt.Errorf("no session key configured for domain: %s", domain) return "", fmt.Errorf("no session key configured for domain: %s", domain)
} }
var (
email string
groupIDs []string
groupNames []string
)
if s.usersManager != nil {
user, userGroups, uerr := s.usersManager.GetUserWithGroups(ctx, userID)
if uerr != nil {
log.WithContext(ctx).Debugf("session token mint: lookup user %s: %v", userID, uerr)
} else if user != nil {
email = user.Email
groupIDs, groupNames = pairGroupIDsAndNames(userGroups)
}
}
return sessionkey.SignToken( return sessionkey.SignToken(
service.SessionPrivateKey, service.SessionPrivateKey,
userID, userID,
email,
domain, domain,
method, method,
groupIDs,
groupNames,
proxyauth.DefaultSessionExpiry, proxyauth.DefaultSessionExpiry,
) )
} }
@@ -1453,7 +1532,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil }, nil
} }
userID, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes) userID, _, _, _, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes)
if err != nil { if err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"domain": domain, "domain": domain,
@@ -1466,7 +1545,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil }, nil
} }
user, err := s.usersManager.GetUser(ctx, userID) user, userGroups, err := s.usersManager.GetUserWithGroups(ctx, userID)
if err != nil { if err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"domain": domain, "domain": domain,
@@ -1500,12 +1579,15 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
"user_id": userID, "user_id": userID,
"error": err.Error(), "error": err.Error(),
}).Debug("ValidateSession: access denied") }).Debug("ValidateSession: access denied")
groupIDs, groupNames := pairGroupIDsAndNames(userGroups)
//nolint:nilerr //nolint:nilerr
return &proto.ValidateSessionResponse{ return &proto.ValidateSessionResponse{
Valid: false, Valid: false,
UserId: user.Id, UserId: user.Id,
UserEmail: user.Email, UserEmail: user.Email,
DeniedReason: "not_in_group", DeniedReason: "not_in_group",
PeerGroupIds: groupIDs,
PeerGroupNames: groupNames,
}, nil }, nil
} }
@@ -1515,10 +1597,13 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
"email": user.Email, "email": user.Email,
}).Debug("ValidateSession: access granted") }).Debug("ValidateSession: access granted")
groupIDs, groupNames := pairGroupIDsAndNames(userGroups)
return &proto.ValidateSessionResponse{ return &proto.ValidateSessionResponse{
Valid: true, Valid: true,
UserId: user.Id, UserId: user.Id,
UserEmail: user.Email, UserEmail: user.Email,
PeerGroupIds: groupIDs,
PeerGroupNames: groupNames,
}, nil }, nil
} }
@@ -1551,3 +1636,154 @@ func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *
} }
func ptr[T any](v T) *T { return &v } func ptr[T any](v T) *T { return &v }
// ValidateTunnelPeer resolves an inbound peer by its WireGuard tunnel IP and
// checks the peer's group membership against the service's access groups.
// Peers without a user (machine agents, automation workloads) are first-class
// callers; authorisation runs off peer-group memberships rather than the
// optional owning user's auto-groups. On success a session JWT is minted so
// the proxy can install a cookie and skip subsequent management round-trips.
func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
domain := req.GetDomain()
tunnelIPStr := req.GetTunnelIp()
if domain == "" || tunnelIPStr == "" {
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "missing domain or tunnel_ip",
}, nil
}
tunnelIP := net.ParseIP(tunnelIPStr)
if tunnelIP == nil {
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "invalid_tunnel_ip",
}, nil
}
service, err := s.getServiceByDomain(ctx, domain)
if err != nil {
log.WithFields(log.Fields{"domain": domain, "error": err.Error()}).Debug("ValidateTunnelPeer: service not found")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "service_not_found",
}, nil
}
// Mirror ValidateSession: account-scoped (BYOP) proxy tokens may only
// validate and mint session cookies for their own account's domains.
if err := enforceAccountScope(ctx, service.AccountID); err != nil {
return nil, err
}
peer, err := s.peersManager.GetPeerByTunnelIP(ctx, service.AccountID, tunnelIP)
if err != nil || peer == nil {
log.WithFields(log.Fields{"domain": domain, "tunnel_ip": tunnelIPStr}).Debug("ValidateTunnelPeer: peer not found")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "peer_not_found",
}, nil
}
_, peerGroups, err := s.peersManager.GetPeerWithGroups(ctx, service.AccountID, peer.ID)
if err != nil {
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: peer groups lookup failed")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "peer_not_found",
}, nil
}
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
// Resolve the principal: when the peer is linked to a user, the human
// is the principal so multiple peers owned by the same user share a
// single identity. Unlinked peers (machine agents) are their own
// principal keyed on peer.ID. displayIdentity is what upstream gateways
// tag spend with — user.Email when linked, peer.Name when not.
principalID := peer.ID
displayIdentity := peer.Name
if peer.UserID != "" {
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
principalID = user.Id
if user.Email != "" {
displayIdentity = user.Email
}
}
}
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
UserId: principalID,
UserEmail: displayIdentity,
DeniedReason: "not_in_group",
PeerGroupIds: groupIDs,
PeerGroupNames: groupNames,
}, nil
}
token, err := s.generateSessionToken(ctx, true, service, principalID, displayIdentity, proxyauth.MethodOIDC, groupIDs, groupNames)
if err != nil {
return nil, err
}
log.WithFields(log.Fields{
"domain": domain,
"tunnel_ip": tunnelIPStr,
"peer_id": peer.ID,
"principal_id": principalID,
}).Debug("ValidateTunnelPeer: access granted")
return &proto.ValidateTunnelPeerResponse{
Valid: true,
UserId: principalID,
UserEmail: displayIdentity,
SessionToken: token,
PeerGroupIds: groupIDs,
PeerGroupNames: groupNames,
}, nil
}
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
// groups. Private services authorise against AccessGroups (empty list fails
// closed — Validate() rejects that at save time but the RPC is the security
// boundary and must not trust upstream state). Bearer-auth services authorise
// against DistributionGroups when populated. Non-private non-bearer services
// are open.
func checkPeerGroupAccess(service *rpservice.Service, peerGroupIDs []string) error {
if service.Private {
if len(service.AccessGroups) == 0 {
return fmt.Errorf("private service has no access groups")
}
return matchAnyGroup(service.AccessGroups, peerGroupIDs)
}
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled && len(service.Auth.BearerAuth.DistributionGroups) > 0 {
return matchAnyGroup(service.Auth.BearerAuth.DistributionGroups, peerGroupIDs)
}
return nil
}
// matchAnyGroup returns nil when peerGroupIDs intersects allowedGroups,
// else a non-nil error.
func matchAnyGroup(allowedGroups, peerGroupIDs []string) error {
if len(allowedGroups) == 0 {
return fmt.Errorf("no allowed groups configured")
}
allowed := make(map[string]struct{}, len(allowedGroups))
for _, g := range allowedGroups {
allowed[g] = struct{}{}
}
for _, g := range peerGroupIDs {
if _, ok := allowed[g]; ok {
return nil
}
}
return fmt.Errorf("peer not in allowed groups")
}

View File

@@ -129,6 +129,14 @@ func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.U
return user, nil return user, nil
} }
func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) {
user, err := m.GetUser(ctx, userID)
if err != nil {
return nil, nil, err
}
return user, nil, nil
}
func TestValidateUserGroupAccess(t *testing.T) { func TestValidateUserGroupAccess(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -420,3 +428,46 @@ func TestGetAccountProxyByDomain(t *testing.T) {
}) })
} }
} }
func TestCheckPeerGroupAccess(t *testing.T) {
t.Run("private with empty AccessGroups denies", func(t *testing.T) {
svc := &service.Service{Private: true, AccessGroups: nil}
err := checkPeerGroupAccess(svc, []string{"grp-admins"})
require.Error(t, err)
assert.Contains(t, err.Error(), "no access groups")
})
t.Run("private with peer in AccessGroups allows", func(t *testing.T) {
svc := &service.Service{Private: true, AccessGroups: []string{"grp-admins", "grp-ops"}}
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-other", "grp-ops"}))
})
t.Run("private with peer outside AccessGroups denies", func(t *testing.T) {
svc := &service.Service{Private: true, AccessGroups: []string{"grp-admins"}}
assert.Error(t, checkPeerGroupAccess(svc, []string{"grp-other"}))
})
t.Run("bearer enabled with empty DistributionGroups allows", func(t *testing.T) {
svc := &service.Service{
Auth: service.AuthConfig{BearerAuth: &service.BearerAuthConfig{Enabled: true}},
}
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-anyone"}))
})
t.Run("bearer enabled gates on DistributionGroups", func(t *testing.T) {
svc := &service.Service{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"grp-allowed"},
},
},
}
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-allowed"}))
assert.Error(t, checkPeerGroupAccess(svc, []string{"grp-other"}))
})
t.Run("non-private non-bearer is open", func(t *testing.T) {
assert.NoError(t, checkPeerGroupAccess(&service.Service{}, nil))
})
}

View File

@@ -437,7 +437,7 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
return nil return nil
} }
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) log.WithContext(ctx).Tracef("received an update for peer %s", peerKey.String())
if debouncer.ProcessUpdate(update) { if debouncer.ProcessUpdate(update) {
// Send immediately (first update or after quiet period) // Send immediately (first update or after quiet period)
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil { if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
@@ -492,7 +492,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
return status.Errorf(codes.Internal, "failed sending update message") return status.Errorf(codes.Internal, "failed sending update message")
} }
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String()) log.WithContext(ctx).Tracef("sent an update to peer %s", peerKey.String())
return nil return nil
} }
@@ -821,6 +821,80 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
}, nil }, nil
} }
// ExtendAuthSession refreshes the peer's SSO session expiry deadline using a
// fresh JWT. The same JWT validation pipeline as Login is used. The tunnel
// stays up; no network map sync is performed. The new deadline is returned
// in ExtendAuthSessionResponse.SessionExpiresAt.
func (s *Server) ExtendAuthSession(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
extendReq := &proto.ExtendAuthSessionRequest{}
peerKey, err := s.parseRequest(ctx, req, extendReq)
if err != nil {
return nil, err
}
//nolint
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
if accountID, accErr := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()); accErr == nil {
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
}
jwt := extendReq.GetJwtToken()
if jwt == "" {
return nil, status.Errorf(codes.InvalidArgument, "jwt token is required")
}
var userID string
const attempts = 3
for i := 0; i < attempts; i++ {
userID, err = s.validateToken(ctx, peerKey.String(), jwt)
if err == nil {
break
}
if i == attempts-1 {
break
}
log.WithContext(ctx).Warnf("failed validating JWT token while extending session for peer %s: %v. Retrying (idP cache).", peerKey.String(), err)
select {
case <-time.After(200 * time.Millisecond):
case <-ctx.Done():
return nil, ctx.Err()
}
}
if err != nil {
return nil, err
}
if userID == "" {
return nil, status.Errorf(codes.Unauthenticated, "jwt token did not yield a user id")
}
deadline, err := s.accountManager.ExtendPeerSession(ctx, peerKey.String(), userID)
if err != nil {
log.WithContext(ctx).Warnf("failed extending session for peer %s: %v", peerKey.String(), err)
return nil, mapError(ctx, err)
}
// Success path normally returns a non-zero deadline. A defensive zero
// would still encode as the explicit "disabled" sentinel rather than nil,
// so the client clears any stale anchor instead of preserving it.
resp := &proto.ExtendAuthSessionResponse{
SessionExpiresAt: encodeSessionExpiresAt(deadline),
}
wgKey, err := s.secretsManager.GetWGKey()
if err != nil {
return nil, status.Errorf(codes.Internal, "failed processing request")
}
encrypted, err := encryption.EncryptMessage(peerKey, wgKey, resp)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed encrypting response")
}
return &proto.EncryptedMessage{
WgPubKey: wgKey.PublicKey().String(),
Body: encrypted,
}, nil
}
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) { func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
var relayToken *Token var relayToken *Token
var err error var err error
@@ -844,6 +918,12 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
Checks: toProtocolChecks(ctx, postureChecks), Checks: toProtocolChecks(ctx, postureChecks),
} }
// settings is always non-nil here, so we never emit nil — encoder returns
// either a valid deadline or the explicit-zero "disabled" sentinel.
loginResp.SessionExpiresAt = encodeSessionExpiresAt(
peer.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration),
)
return loginResp, nil return loginResp, nil
} }
@@ -932,7 +1012,31 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
return status.Errorf(codes.Internal, "failed to get peer groups %s", err) return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
} }
plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) dnsName := s.networkMapController.GetDNSDomain(settings)
var plainResp *proto.SyncResponse
if s.networkMapController.PeerNeedsComponents(peer) {
// Capable peer: discard the legacy NetworkMap that SyncAndMarkPeer
// computed and recompute the raw components instead. This wastes one
// Calculate() call per initial-sync — the component-based wire
// format is what the peer actually consumes. The streaming path
// (network_map.Controller.UpdateAccountPeers) skips this duplication
// because it dispatches by capability before computing.
//
// TODO: refactor SyncPeer / SyncAndMarkPeer / their mocks + manager
// interfaces to return PeerNetworkMapResult so the initial-sync path
// stops doing duplicate work. Deferred until the client-side
// decoder lands and there's a real deployment of capability=3 peers
// worth optimizing for.
_, components, proxyPatch, _, _, err := s.networkMapController.GetValidatedPeerWithComponents(ctx, false, peer.AccountID, peer)
if err != nil {
log.WithContext(ctx).Errorf("failed to build components for peer %s on initial sync: %v", peer.ID, err)
return status.Errorf(codes.Internal, "failed to build initial sync envelope")
}
plainResp = ToComponentSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, components, proxyPatch, dnsName, postureChecks, settings, settings.Extra, peerGroups, dnsFwdPort)
} else {
plainResp = ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, dnsName, postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
}
key, err := s.secretsManager.GetWGKey() key, err := s.secretsManager.GetWGKey()
if err != nil { if err != nil {

View File

@@ -102,7 +102,7 @@ func generateSessionKeyPair(t *testing.T) (string, string) {
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string { func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
t.Helper() t.Helper()
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, time.Hour) token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, nil, time.Hour)
require.NoError(t, err) require.NoError(t, err)
return token return token
} }
@@ -125,6 +125,7 @@ func TestValidateSession_UserAllowed(t *testing.T) {
assert.True(t, resp.Valid, "User should be allowed access") assert.True(t, resp.Valid, "User should be allowed access")
assert.Equal(t, "allowedUserId", resp.UserId) assert.Equal(t, "allowedUserId", resp.UserId)
assert.Empty(t, resp.DeniedReason) assert.Empty(t, resp.DeniedReason)
assert.Equal(t, []string{"allowedGroupId"}, resp.GetPeerGroupIds(), "PeerGroupIds must mirror the resolved user's group memberships")
} }
func TestValidateSession_UserNotInAllowedGroup(t *testing.T) { func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
@@ -145,6 +146,7 @@ func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
assert.False(t, resp.Valid, "User not in group should be denied") assert.False(t, resp.Valid, "User not in group should be denied")
assert.Equal(t, "not_in_group", resp.DeniedReason) assert.Equal(t, "not_in_group", resp.DeniedReason)
assert.Equal(t, "nonGroupUserId", resp.UserId) assert.Equal(t, "nonGroupUserId", resp.UserId)
assert.Empty(t, resp.GetPeerGroupIds(), "PeerGroupIds must mirror the resolved user's actual (empty) memberships on denial")
} }
func TestValidateSession_UserInDifferentAccount(t *testing.T) { func TestValidateSession_UserInDifferentAccount(t *testing.T) {

View File

@@ -355,7 +355,17 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled || oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
oldSettings.DNSDomain != newSettings.DNSDomain || oldSettings.DNSDomain != newSettings.DNSDomain ||
oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion || oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion ||
oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways { oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways ||
oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled ||
oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
// Session deadline is derived from LastLogin + PeerLoginExpiration
// on every Login/Sync response. Without a fan-out push, connected
// peers keep the deadline they received at login time and only see
// the new value after the next unrelated NetworkMap change. Add
// these two fields to the trigger list so admin-side expiry tweaks
// (e.g. shortening from 24h to 1h) reach every connected peer
// within seconds, which is what the proactive-warning feature
// relies on (see client/internal/auth/sessionwatch).
updateAccountPeers = true updateAccountPeers = true
} }
@@ -1621,6 +1631,14 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
return nil return nil
} }
for _, g := range newGroupsToCreate {
seq, err := transaction.AllocateAccountSeqID(ctx, userAuth.AccountId, types.AccountSeqEntityGroup)
if err != nil {
return fmt.Errorf("error allocating group seq id: %w", err)
}
g.AccountSeqID = seq
}
if err = transaction.CreateGroups(ctx, userAuth.AccountId, newGroupsToCreate); err != nil { if err = transaction.CreateGroups(ctx, userAuth.AccountId, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err) return fmt.Errorf("error saving groups: %w", err)
} }

View File

@@ -109,6 +109,7 @@ type Manager interface {
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
GetExternalCacheManager() ExternalCacheManager GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)

View File

@@ -1304,6 +1304,21 @@ func (mr *MockManagerMockRecorder) LoginPeer(ctx, login interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginPeer", reflect.TypeOf((*MockManager)(nil).LoginPeer), ctx, login) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginPeer", reflect.TypeOf((*MockManager)(nil).LoginPeer), ctx, login)
} }
// ExtendPeerSession mocks base method.
func (m *MockManager) ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ExtendPeerSession", ctx, peerPubKey, userID)
ret0, _ := ret[0].(time.Time)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ExtendPeerSession indicates an expected call of ExtendPeerSession.
func (mr *MockManagerMockRecorder) ExtendPeerSession(ctx, peerPubKey, userID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExtendPeerSession", reflect.TypeOf((*MockManager)(nil).ExtendPeerSession), ctx, peerPubKey, userID)
}
// MarkPeerConnected mocks base method. // MarkPeerConnected mocks base method.
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error { func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@@ -3036,6 +3036,16 @@ func TestAccount_SetJWTGroups(t *testing.T) {
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "new group should be added") assert.Len(t, user.AutoGroups, 1, "new group should be added")
var newJWTGroup *types.Group
for _, g := range groups {
if g.Name == "group3" {
newJWTGroup = g
break
}
}
require.NotNil(t, newJWTGroup, "JIT-created JWT group not found")
assert.NotZero(t, newJWTGroup.AccountSeqID, "JIT-created JWT group must have a non-zero AccountSeqID")
}) })
t.Run("remove all JWT groups when list is empty", func(t *testing.T) { t.Run("remove all JWT groups when list is empty", func(t *testing.T) {

View File

@@ -240,6 +240,10 @@ const (
AccountLocalMfaEnabled Activity = 123 AccountLocalMfaEnabled Activity = 123
// AccountLocalMfaDisabled indicates that a user disabled TOTP MFA for local users // AccountLocalMfaDisabled indicates that a user disabled TOTP MFA for local users
AccountLocalMfaDisabled Activity = 124 AccountLocalMfaDisabled Activity = 124
// UserExtendedPeerSession indicates that a user refreshed their peer's
// SSO session deadline via ExtendAuthSession without re-establishing the
// tunnel. Distinct from UserLoggedInPeer (full interactive login).
UserExtendedPeerSession Activity = 125
AccountDeleted Activity = 99999 AccountDeleted Activity = 99999
) )
@@ -394,6 +398,8 @@ var activityMap = map[Activity]Code{
AccountLocalMfaEnabled: {"Account local MFA enabled", "account.setting.local.mfa.enable"}, AccountLocalMfaEnabled: {"Account local MFA enabled", "account.setting.local.mfa.enable"},
AccountLocalMfaDisabled: {"Account local MFA disabled", "account.setting.local.mfa.disable"}, AccountLocalMfaDisabled: {"Account local MFA disabled", "account.setting.local.mfa.disable"},
UserExtendedPeerSession: {"User extended peer session", "user.peer.session.extend"},
DomainAdded: {"Domain added", "domain.add"}, DomainAdded: {"Domain added", "domain.add"},
DomainDeleted: {"Domain deleted", "domain.delete"}, DomainDeleted: {"Domain deleted", "domain.delete"},
DomainValidated: {"Domain validated", "domain.validate"}, DomainValidated: {"Domain validated", "domain.validate"},

View File

@@ -96,6 +96,12 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
return err return err
} }
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
if err != nil {
return status.Errorf(status.Internal, "failed to allocate group seq id: %v", err)
}
newGroup.AccountSeqID = seq
if err := transaction.CreateGroup(ctx, newGroup); err != nil { if err := transaction.CreateGroup(ctx, newGroup); err != nil {
return status.Errorf(status.Internal, "failed to create group: %v", err) return status.Errorf(status.Internal, "failed to create group: %v", err)
} }
@@ -170,6 +176,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
return err return err
} }
newGroup.AccountSeqID = oldGroup.AccountSeqID
if err = transaction.UpdateGroup(ctx, newGroup); err != nil { if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
return err return err
} }
@@ -221,6 +229,12 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
newGroup.AccountID = accountID newGroup.AccountID = accountID
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
if err != nil {
return err
}
newGroup.AccountSeqID = seq
if err = transaction.CreateGroup(ctx, newGroup); err != nil { if err = transaction.CreateGroup(ctx, newGroup); err != nil {
return err return err
} }
@@ -320,6 +334,12 @@ func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountI
newGroup.AccountID = accountID newGroup.AccountID = accountID
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
if err != nil {
return err
}
newGroup.AccountSeqID = oldGroup.AccountSeqID
if err := transaction.UpdateGroup(ctx, newGroup); err != nil { if err := transaction.UpdateGroup(ctx, newGroup); err != nil {
return err return err
} }

View File

@@ -15,15 +15,13 @@ import (
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager" reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
idpmanager "github.com/netbirdio/netbird/management/server/idp" idpmanager "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager" zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
@@ -32,12 +30,10 @@ import (
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/http/handlers/proxy" "github.com/netbirdio/netbird/management/server/http/handlers/proxy"
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
nbgroups "github.com/netbirdio/netbird/management/server/groups" nbgroups "github.com/netbirdio/netbird/management/server/groups"
@@ -56,17 +52,14 @@ import (
"github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/middleware/bypass"
nbinstance "github.com/netbirdio/netbird/management/server/instance" nbinstance "github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
nbnetworks "github.com/netbirdio/netbird/management/server/networks" nbnetworks "github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
) )
const apiPrefix = "/api"
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) { func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter, isValidChildAccount middleware.IsValidChildAccountFunc) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints // Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil { if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -100,25 +93,16 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
accountManager.GetUserFromUserAuth, accountManager.GetUserFromUserAuth,
rateLimiter, rateLimiter,
appMetrics.GetMeter(), appMetrics.GetMeter(),
isValidChildAccount,
) )
corsMiddleware := cors.AllowAll() corsMiddleware := cors.AllowAll()
rootRouter := mux.NewRouter()
metricsMiddleware := appMetrics.HTTPMiddleware() metricsMiddleware := appMetrics.HTTPMiddleware()
prefix := apiPrefix
router := rootRouter.PathPrefix(prefix).Subrouter()
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler) router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler)
if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, integratedValidator, appMetrics.GetMeter(), permissionsManager, peersManager, proxyController, settingsManager); err != nil { instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), idpManager)
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}
// Check if embedded IdP is enabled for instance manager
embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager)
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create instance manager: %w", err) return nil, fmt.Errorf("failed to create instance manager: %w", err)
} }
@@ -154,10 +138,5 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
oauthHandler.RegisterEndpoints(router) oauthHandler.RegisterEndpoints(router)
} }
// Mount embedded IdP handler at /oauth2 path if configured return router, nil
if embeddedIdpEnabled {
rootRouter.PathPrefix("/oauth2").Handler(corsMiddleware.Handler(embeddedIdP.Handler()))
}
return rootRouter, nil
} }

View File

@@ -11,8 +11,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric"
"github.com/netbirdio/management-integrations/integrations"
serverauth "github.com/netbirdio/netbird/management/server/auth" serverauth "github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/middleware/bypass"
@@ -27,6 +25,8 @@ type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) err
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
type IsValidChildAccountFunc func(ctx context.Context, userID, accountID, childAccountID string) bool
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct { type AuthMiddleware struct {
authManager serverauth.Manager authManager serverauth.Manager
@@ -35,6 +35,7 @@ type AuthMiddleware struct {
syncUserJWTGroups SyncUserJWTGroupsFunc syncUserJWTGroups SyncUserJWTGroupsFunc
rateLimiter *APIRateLimiter rateLimiter *APIRateLimiter
patUsageTracker *PATUsageTracker patUsageTracker *PATUsageTracker
isValidChildAccount IsValidChildAccountFunc
} }
// NewAuthMiddleware instance constructor // NewAuthMiddleware instance constructor
@@ -45,6 +46,7 @@ func NewAuthMiddleware(
getUserFromUserAuth GetUserFromUserAuthFunc, getUserFromUserAuth GetUserFromUserAuthFunc,
rateLimiter *APIRateLimiter, rateLimiter *APIRateLimiter,
meter metric.Meter, meter metric.Meter,
isValidChildAccount IsValidChildAccountFunc,
) *AuthMiddleware { ) *AuthMiddleware {
var patUsageTracker *PATUsageTracker var patUsageTracker *PATUsageTracker
if meter != nil { if meter != nil {
@@ -62,6 +64,7 @@ func NewAuthMiddleware(
getUserFromUserAuth: getUserFromUserAuth, getUserFromUserAuth: getUserFromUserAuth,
rateLimiter: rateLimiter, rateLimiter: rateLimiter,
patUsageTracker: patUsageTracker, patUsageTracker: patUsageTracker,
isValidChildAccount: isValidChildAccount,
} }
} }
@@ -124,7 +127,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
} }
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 { if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
if integrations.IsValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) { if m.isValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) {
userAuth.AccountId = impersonate[0] userAuth.AccountId = impersonate[0]
userAuth.IsChild = true userAuth.IsChild = true
} }
@@ -203,7 +206,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
} }
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 { if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
if integrations.IsValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) { if m.isValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) {
userAuth.AccountId = impersonate[0] userAuth.AccountId = impersonate[0]
userAuth.IsChild = true userAuth.IsChild = true
} }

View File

@@ -211,6 +211,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
}, },
disabledLimiter, disabledLimiter,
nil, nil,
func(_ context.Context, _, _, _ string) bool { return false },
) )
handlerToTest := authMiddleware.Handler(nextHandler) handlerToTest := authMiddleware.Handler(nextHandler)
@@ -270,6 +271,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
}, },
NewAPIRateLimiter(rateLimitConfig), NewAPIRateLimiter(rateLimitConfig),
nil, nil,
func(_ context.Context, _, _, _ string) bool { return false },
) )
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -322,6 +324,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
}, },
NewAPIRateLimiter(rateLimitConfig), NewAPIRateLimiter(rateLimitConfig),
nil, nil,
func(_ context.Context, _, _, _ string) bool { return false },
) )
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -365,6 +368,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
}, },
NewAPIRateLimiter(rateLimitConfig), NewAPIRateLimiter(rateLimitConfig),
nil, nil,
func(_ context.Context, _, _, _ string) bool { return false },
) )
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -409,6 +413,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
}, },
NewAPIRateLimiter(rateLimitConfig), NewAPIRateLimiter(rateLimitConfig),
nil, nil,
func(_ context.Context, _, _, _ string) bool { return false },
) )
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -473,6 +478,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
}, },
NewAPIRateLimiter(rateLimitConfig), NewAPIRateLimiter(rateLimitConfig),
nil, nil,
func(_ context.Context, _, _, _ string) bool { return false },
) )
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -532,6 +538,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
}, },
NewAPIRateLimiter(rateLimitConfig), NewAPIRateLimiter(rateLimitConfig),
nil, nil,
func(_ context.Context, _, _, _ string) bool { return false },
) )
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -587,6 +594,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
}, },
NewAPIRateLimiter(rateLimitConfig), NewAPIRateLimiter(rateLimitConfig),
nil, nil,
func(_ context.Context, _, _, _ string) bool { return false },
) )
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -687,6 +695,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
}, },
disabledLimiter, disabledLimiter,
nil, nil,
func(_ context.Context, _, _, _ string) bool { return false },
) )
for _, tc := range tt { for _, tc := range tt {

View File

@@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go.opentelemetry.io/otel/metric/noop" "go.opentelemetry.io/otel/metric/noop"
@@ -135,7 +136,8 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil) apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create API handler: %v", err) t.Fatalf("Failed to create API handler: %v", err)
} }
@@ -264,7 +266,8 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil) apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create API handler: %v", err) t.Fatalf("Failed to create API handler: %v", err)
} }

View File

@@ -0,0 +1,62 @@
package validator
import (
"context"
cachestore "github.com/eko/gocache/lib/v4/store"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
type IntegratedValidatorImpl struct{}
func NewIntegratedValidator(_ context.Context, _ peers.Manager, _ settings.Manager, _ activity.Store, _ cachestore.StoreInterface) (*IntegratedValidatorImpl, error) {
return &IntegratedValidatorImpl{}, nil
}
func (v *IntegratedValidatorImpl) ValidateExtraSettings(context.Context, *types.ExtraSettings, *types.ExtraSettings, string, string) error {
return nil
}
func (v *IntegratedValidatorImpl) ValidatePeer(_ context.Context, update *nbpeer.Peer, _ *nbpeer.Peer, _ string, _ string, _ string, _ []string, _ *types.ExtraSettings) (*nbpeer.Peer, bool, error) {
return update, false, nil
}
func (v *IntegratedValidatorImpl) PreparePeer(_ context.Context, _ string, peer *nbpeer.Peer, _ []string, _ *types.ExtraSettings, _ bool) *nbpeer.Peer {
return peer.Copy()
}
func (v *IntegratedValidatorImpl) IsNotValidPeer(_ context.Context, _ string, _ *nbpeer.Peer, _ []string, _ *types.ExtraSettings) (bool, bool, error) {
return false, false, nil
}
func (v *IntegratedValidatorImpl) GetValidatedPeers(_ context.Context, _ string, _ []*types.Group, peers []*nbpeer.Peer, _ *types.ExtraSettings) (map[string]struct{}, error) {
validatedPeers := make(map[string]struct{})
for _, p := range peers {
validatedPeers[p.ID] = struct{}{}
}
return validatedPeers, nil
}
func (v *IntegratedValidatorImpl) GetInvalidPeers(_ context.Context, _ string, _ *types.ExtraSettings) (map[string]string, error) {
return make(map[string]string), nil
}
func (v *IntegratedValidatorImpl) PeerDeleted(_ context.Context, _, _ string, _ *types.ExtraSettings) error {
return nil
}
func (v *IntegratedValidatorImpl) SetPeerInvalidationListener(_ func(accountID string, peerIDs []string)) {
}
func (v *IntegratedValidatorImpl) Stop(_ context.Context) {
}
func (v *IntegratedValidatorImpl) ValidateFlowResponse(_ context.Context, _ string, flowResponse *proto.PKCEAuthorizationFlow) *proto.PKCEAuthorizationFlow {
return flowResponse
}

View File

@@ -17,6 +17,7 @@ import (
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
nbversion "github.com/netbirdio/netbird/version" nbversion "github.com/netbirdio/netbird/version"
) )
@@ -53,6 +54,7 @@ type DataSource interface {
GetAllAccounts(ctx context.Context) []*types.Account GetAllAccounts(ctx context.Context) []*types.Account
GetStoreEngine() types.Engine GetStoreEngine() types.Engine
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error) GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
GetProxyMetrics(ctx context.Context) (store.ProxyMetrics, error)
} }
// ConnManager peer connection manager that holds state for current active connections // ConnManager peer connection manager that holds state for current active connections
@@ -223,6 +225,12 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
servicesAuthPassword int servicesAuthPassword int
servicesAuthPin int servicesAuthPin int
servicesAuthOIDC int servicesAuthOIDC int
// Private-service signals — track adoption of NetBird-only mode
// (services backed by an embedded proxy peer + access groups).
servicesPrivate int
servicesPrivateWithGroups int
servicesPrivateAccessGroupsSum int
servicesWithDirectUpstream int
) )
start := time.Now() start := time.Now()
metricsProperties := make(properties) metricsProperties := make(properties)
@@ -380,9 +388,31 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled { if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled {
servicesAuthOIDC++ servicesAuthOIDC++
} }
if service.Private {
servicesPrivate++
if len(service.AccessGroups) > 0 {
servicesPrivateWithGroups++
}
servicesPrivateAccessGroupsSum += len(service.AccessGroups)
}
for _, target := range service.Targets {
if target.Options.DirectUpstream {
servicesWithDirectUpstream++
break
}
}
} }
} }
// Proxy / BYOP cluster signals come from the proxies table aggregated
// across all accounts in a single store query; nil on FileStore.
proxyMetrics, err := w.dataSource.GetProxyMetrics(ctx)
if err != nil {
log.WithContext(ctx).Debugf("collect proxy metrics: %v", err)
}
minActivePeerVersion, maxActivePeerVersion := getMinMaxVersion(peerActiveVersions) minActivePeerVersion, maxActivePeerVersion := getMinMaxVersion(peerActiveVersions)
metricsProperties["uptime"] = uptime metricsProperties["uptime"] = uptime
metricsProperties["accounts"] = accounts metricsProperties["accounts"] = accounts
@@ -430,6 +460,15 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
metricsProperties["services_auth_password"] = servicesAuthPassword metricsProperties["services_auth_password"] = servicesAuthPassword
metricsProperties["services_auth_pin"] = servicesAuthPin metricsProperties["services_auth_pin"] = servicesAuthPin
metricsProperties["services_auth_oidc"] = servicesAuthOIDC metricsProperties["services_auth_oidc"] = servicesAuthOIDC
metricsProperties["services_private"] = servicesPrivate
metricsProperties["services_private_with_access_groups"] = servicesPrivateWithGroups
metricsProperties["services_private_access_groups_sum"] = servicesPrivateAccessGroupsSum
metricsProperties["services_with_direct_upstream"] = servicesWithDirectUpstream
metricsProperties["proxy_clusters"] = proxyMetrics.Clusters
metricsProperties["proxy_clusters_byop"] = proxyMetrics.ClustersBYOP
metricsProperties["proxy_clusters_private"] = proxyMetrics.ClustersPrivate
metricsProperties["proxies"] = proxyMetrics.Proxies
metricsProperties["proxies_connected"] = proxyMetrics.ProxiesConnected
metricsProperties["custom_domains"] = customDomains metricsProperties["custom_domains"] = customDomains
metricsProperties["custom_domains_validated"] = customDomainsValidated metricsProperties["custom_domains_validated"] = customDomainsValidated

View File

@@ -12,6 +12,7 @@ import (
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -123,7 +124,7 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
Enabled: true, Enabled: true,
Targets: []*rpservice.Target{ Targets: []*rpservice.Target{
{TargetType: "peer"}, {TargetType: "peer"},
{TargetType: "host"}, {TargetType: "host", Options: rpservice.TargetOptions{DirectUpstream: true}},
}, },
Auth: rpservice.AuthConfig{ Auth: rpservice.AuthConfig{
PasswordAuth: &rpservice.PasswordAuthConfig{Enabled: true}, PasswordAuth: &rpservice.PasswordAuthConfig{Enabled: true},
@@ -141,6 +142,16 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
}, },
Meta: rpservice.Meta{Status: string(rpservice.StatusPending)}, Meta: rpservice.Meta{Status: string(rpservice.StatusPending)},
}, },
{
ID: "svc3-private",
Enabled: true,
Private: true,
AccessGroups: []string{"grp-eng", "grp-ops"},
Targets: []*rpservice.Target{
{TargetType: "cluster", Options: rpservice.TargetOptions{DirectUpstream: true}},
},
Meta: rpservice.Meta{Status: string(rpservice.StatusActive)},
},
}, },
}, },
{ {
@@ -254,6 +265,18 @@ func (mockDatasource) GetCustomDomainsCounts(_ context.Context) (int64, int64, e
return 3, 2, nil return 3, 2, nil
} }
// GetProxyMetrics returns canned proxy/cluster counts so the
// generateProperties test can assert the BYOP signals end-to-end.
func (mockDatasource) GetProxyMetrics(_ context.Context) (store.ProxyMetrics, error) {
return store.ProxyMetrics{
Clusters: 3,
ClustersBYOP: 1,
ClustersPrivate: 1,
Proxies: 4,
ProxiesConnected: 2,
}, nil
}
// TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties // TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties
func TestGenerateProperties(t *testing.T) { func TestGenerateProperties(t *testing.T) {
ds := mockDatasource{} ds := mockDatasource{}
@@ -393,17 +416,17 @@ func TestGenerateProperties(t *testing.T) {
t.Errorf("expected 3 embedded_idp_count, got %v", properties["embedded_idp_count"]) t.Errorf("expected 3 embedded_idp_count, got %v", properties["embedded_idp_count"])
} }
if properties["services"] != 2 { if properties["services"] != 3 {
t.Errorf("expected 2 services, got %v", properties["services"]) t.Errorf("expected 3 services, got %v", properties["services"])
} }
if properties["services_enabled"] != 1 { if properties["services_enabled"] != 2 {
t.Errorf("expected 1 services_enabled, got %v", properties["services_enabled"]) t.Errorf("expected 2 services_enabled, got %v", properties["services_enabled"])
} }
if properties["services_targets"] != 3 { if properties["services_targets"] != 4 {
t.Errorf("expected 3 services_targets, got %v", properties["services_targets"]) t.Errorf("expected 4 services_targets, got %v", properties["services_targets"])
} }
if properties["services_status_active"] != 1 { if properties["services_status_active"] != 2 {
t.Errorf("expected 1 services_status_active, got %v", properties["services_status_active"]) t.Errorf("expected 2 services_status_active, got %v", properties["services_status_active"])
} }
if properties["services_status_pending"] != 1 { if properties["services_status_pending"] != 1 {
t.Errorf("expected 1 services_status_pending, got %v", properties["services_status_pending"]) t.Errorf("expected 1 services_status_pending, got %v", properties["services_status_pending"])
@@ -420,6 +443,9 @@ func TestGenerateProperties(t *testing.T) {
if properties["services_target_type_domain"] != 1 { if properties["services_target_type_domain"] != 1 {
t.Errorf("expected 1 services_target_type_domain, got %v", properties["services_target_type_domain"]) t.Errorf("expected 1 services_target_type_domain, got %v", properties["services_target_type_domain"])
} }
if properties["services_target_type_cluster"] != 1 {
t.Errorf("expected 1 services_target_type_cluster, got %v", properties["services_target_type_cluster"])
}
if properties["services_auth_password"] != 1 { if properties["services_auth_password"] != 1 {
t.Errorf("expected 1 services_auth_password, got %v", properties["services_auth_password"]) t.Errorf("expected 1 services_auth_password, got %v", properties["services_auth_password"])
} }
@@ -429,6 +455,33 @@ func TestGenerateProperties(t *testing.T) {
if properties["services_auth_pin"] != 0 { if properties["services_auth_pin"] != 0 {
t.Errorf("expected 0 services_auth_pin, got %v", properties["services_auth_pin"]) t.Errorf("expected 0 services_auth_pin, got %v", properties["services_auth_pin"])
} }
if properties["services_private"] != 1 {
t.Errorf("expected 1 services_private, got %v", properties["services_private"])
}
if properties["services_private_with_access_groups"] != 1 {
t.Errorf("expected 1 services_private_with_access_groups, got %v", properties["services_private_with_access_groups"])
}
if properties["services_private_access_groups_sum"] != 2 {
t.Errorf("expected 2 services_private_access_groups_sum, got %v", properties["services_private_access_groups_sum"])
}
if properties["services_with_direct_upstream"] != 2 {
t.Errorf("expected 2 services_with_direct_upstream, got %v", properties["services_with_direct_upstream"])
}
if properties["proxy_clusters"] != int64(3) {
t.Errorf("expected 3 proxy_clusters, got %v", properties["proxy_clusters"])
}
if properties["proxy_clusters_byop"] != int64(1) {
t.Errorf("expected 1 proxy_clusters_byop, got %v", properties["proxy_clusters_byop"])
}
if properties["proxy_clusters_private"] != int64(1) {
t.Errorf("expected 1 proxy_clusters_private, got %v", properties["proxy_clusters_private"])
}
if properties["proxies"] != int64(4) {
t.Errorf("expected 4 proxies, got %v", properties["proxies"])
}
if properties["proxies_connected"] != int64(2) {
t.Errorf("expected 2 proxies_connected, got %v", properties["proxies_connected"])
}
if properties["custom_domains"] != int64(3) { if properties["custom_domains"] != int64(3) {
t.Errorf("expected 3 custom_domains, got %v", properties["custom_domains"]) t.Errorf("expected 3 custom_domains, got %v", properties["custom_domains"])
} }

View File

@@ -0,0 +1,156 @@
package migration
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
"github.com/netbirdio/netbird/management/server/types"
)
// BackfillAccountSeqIDs assigns a deterministic per-account sequential id to all
// rows of `model` whose account_seq_id is zero, then seeds account_seq_counters
// with the next free id per account. Idempotent: safe to re-run; both steps
// no-op once everything is consistent.
//
// Implemented as two table-wide SQL statements with window functions, one
// transaction. Backfilling 246k rows across 154k accounts on Postgres takes
// well under a second instead of the per-account-loop ~2 minutes.
//
// orderColumn is the column to use when assigning the deterministic ordering
// (typically the primary-key string id).
func BackfillAccountSeqIDs[T any](
ctx context.Context,
db *gorm.DB,
entity types.AccountSeqEntity,
orderColumn string,
) error {
var model T
if !db.Migrator().HasTable(&model) {
log.WithContext(ctx).Debugf("backfill seq id: table for %T missing, skip", model)
return nil
}
stmt := &gorm.Statement{DB: db}
if err := stmt.Parse(&model); err != nil {
return fmt.Errorf("parse model: %w", err)
}
table := quoteIdent(db, stmt.Schema.Table)
orderCol := quoteIdent(db, orderColumn)
return db.Transaction(func(tx *gorm.DB) error {
var pending int64
if err := tx.Raw(
fmt.Sprintf("SELECT count(*) FROM %s WHERE account_seq_id IS NULL OR account_seq_id = 0", table),
).Scan(&pending).Error; err != nil {
return fmt.Errorf("count pending on %s: %w", table, err)
}
if pending > 0 {
log.WithContext(ctx).Infof("backfill seq id: %s — %d rows pending", table, pending)
if err := backfillRankSQL(tx, table, orderCol); err != nil {
return fmt.Errorf("rank %s: %w", table, err)
}
}
if err := seedCountersSQL(tx, table, entity); err != nil {
return fmt.Errorf("seed counters for %s: %w", entity, err)
}
return nil
})
}
func quoteIdent(db *gorm.DB, name string) string {
switch db.Dialector.Name() {
case "mysql":
return "`" + name + "`"
case "postgres":
return `"` + name + `"`
default:
return name
}
}
func backfillRankSQL(db *gorm.DB, table, orderCol string) error {
dialect := db.Dialector.Name()
var sql string
switch dialect {
case "postgres", "sqlite":
sql = fmt.Sprintf(`
WITH max_seq AS (
SELECT account_id, COALESCE(MAX(account_seq_id), 0) AS max_seq
FROM %s
GROUP BY account_id
),
ranked AS (
SELECT p.id,
m.max_seq + ROW_NUMBER() OVER (PARTITION BY p.account_id ORDER BY p.%s) AS new_seq
FROM %s p
JOIN max_seq m ON p.account_id = m.account_id
WHERE p.account_seq_id IS NULL OR p.account_seq_id = 0
)
UPDATE %s SET account_seq_id = ranked.new_seq
FROM ranked
WHERE %s.id = ranked.id
`, table, orderCol, table, table, table)
case "mysql":
sql = fmt.Sprintf(`
UPDATE %s p
JOIN (
SELECT account_id, COALESCE(MAX(account_seq_id), 0) AS max_seq
FROM %s
GROUP BY account_id
) m ON p.account_id = m.account_id
JOIN (
SELECT id, ROW_NUMBER() OVER (PARTITION BY account_id ORDER BY %s) AS rn
FROM %s
WHERE account_seq_id IS NULL OR account_seq_id = 0
) r ON p.id = r.id
SET p.account_seq_id = m.max_seq + r.rn
`, table, table, orderCol, table)
default:
return fmt.Errorf("unsupported dialect: %s", dialect)
}
return db.Exec(sql).Error
}
func seedCountersSQL(db *gorm.DB, table string, entity types.AccountSeqEntity) error {
dialect := db.Dialector.Name()
var sql string
switch dialect {
case "postgres":
sql = fmt.Sprintf(`
INSERT INTO account_seq_counters (account_id, entity, next_id)
SELECT account_id, ?, MAX(account_seq_id) + 1
FROM %s
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
GROUP BY account_id
ON CONFLICT (account_id, entity) DO UPDATE
SET next_id = GREATEST(account_seq_counters.next_id, EXCLUDED.next_id)
`, table)
case "sqlite":
sql = fmt.Sprintf(`
INSERT INTO account_seq_counters (account_id, entity, next_id)
SELECT account_id, ?, MAX(account_seq_id) + 1
FROM %s
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
GROUP BY account_id
ON CONFLICT (account_id, entity) DO UPDATE
SET next_id = max(account_seq_counters.next_id, excluded.next_id)
`, table)
case "mysql":
sql = fmt.Sprintf(`
INSERT INTO account_seq_counters (account_id, entity, next_id)
SELECT account_id, ?, MAX(account_seq_id) + 1
FROM %s
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
GROUP BY account_id
ON DUPLICATE KEY UPDATE next_id = GREATEST(next_id, VALUES(next_id))
`, table)
default:
return fmt.Errorf("unsupported dialect: %s", dialect)
}
return db.Exec(sql, string(entity)).Error
}

View File

@@ -98,6 +98,7 @@ type MockAccountManager struct {
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
ExtendPeerSessionFunc func(ctx context.Context, peerPubKey, userID string) (time.Time, error)
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error)
@@ -860,6 +861,14 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login types.PeerLog
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented") return nil, nil, nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented")
} }
// ExtendPeerSession mocks ExtendPeerSession of the AccountManager interface
func (am *MockAccountManager) ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) {
if am.ExtendPeerSessionFunc != nil {
return am.ExtendPeerSessionFunc(ctx, peerPubKey, userID)
}
return time.Time{}, status.Errorf(codes.Unimplemented, "method ExtendPeerSession is not implemented")
}
// SyncPeer mocks SyncPeer of the AccountManager interface // SyncPeer mocks SyncPeer of the AccountManager interface
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if am.SyncPeerFunc != nil { if am.SyncPeerFunc != nil {

View File

@@ -69,6 +69,12 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
return err return err
} }
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityNameserverGroup)
if err != nil {
return err
}
newNSGroup.AccountSeqID = seq
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil { if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
return err return err
} }
@@ -120,6 +126,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return err return err
} }
nsGroupToSave.AccountSeqID = oldNSGroup.AccountSeqID
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil { if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
return err return err
} }

View File

@@ -71,9 +71,20 @@ func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network
network.ID = xid.New().String() network.ID = xid.New().String()
err = m.store.SaveNetwork(ctx, network) err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
seq, err := transaction.AllocateAccountSeqID(ctx, network.AccountID, serverTypes.AccountSeqEntityNetwork)
if err != nil {
return fmt.Errorf("failed to allocate network seq id: %w", err)
}
network.AccountSeqID = seq
if err := transaction.SaveNetwork(ctx, network); err != nil {
return fmt.Errorf("failed to save network: %w", err)
}
return nil
})
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to save network: %w", err) return nil, err
} }
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkCreated, network.EventMeta()) m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkCreated, network.EventMeta())
@@ -102,14 +113,25 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
_, err = m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID) err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existing, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID)
if err != nil {
return fmt.Errorf("failed to get network: %w", err)
}
network.AccountSeqID = existing.AccountSeqID
if err := transaction.SaveNetwork(ctx, network); err != nil {
return fmt.Errorf("failed to save network: %w", err)
}
return nil
})
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network: %w", err) return nil, err
} }
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta()) m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta())
return network, m.store.SaveNetwork(ctx, network) return network, nil
} }
func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error { func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {

Some files were not shown because too many files have changed in this diff Show More