mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-30 04:29:57 +00:00
Compare commits
32 Commits
worktree-a
...
nmap/compo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
eab0826b4e | ||
|
|
7048b87931 | ||
|
|
174dc24867 | ||
|
|
596952265d | ||
|
|
7ea5e37dd4 | ||
|
|
9d7ef9b255 | ||
|
|
944a258459 | ||
|
|
21cfec93d4 | ||
|
|
1f9a829f2c | ||
|
|
98818e3095 | ||
|
|
14af179556 | ||
|
|
1fbb5e6d5d | ||
|
|
6771e35d57 | ||
|
|
e89b1e0596 | ||
|
|
5d5c2d9f95 | ||
|
|
d542c60e21 | ||
|
|
4983b5cf17 | ||
|
|
b3b0feb3b8 | ||
|
|
7aebdd69dd | ||
|
|
13e41e432c | ||
|
|
efa6a3f502 | ||
|
|
5fbcdeceac | ||
|
|
3a1bbeba90 | ||
|
|
728057ef15 | ||
|
|
582cd70086 | ||
|
|
9bbbafaf69 | ||
|
|
672b057aa0 | ||
|
|
b9a0186200 | ||
|
|
9083bdb977 | ||
|
|
b194af48b8 | ||
|
|
4543780ef0 | ||
|
|
2de0283971 |
4
.github/workflows/golangci-lint.yml
vendored
4
.github/workflows/golangci-lint.yml
vendored
@@ -35,7 +35,7 @@ jobs:
|
|||||||
display_name: Linux
|
display_name: Linux
|
||||||
name: ${{ matrix.display_name }}
|
name: ${{ matrix.display_name }}
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
timeout-minutes: 15
|
timeout-minutes: 25
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -58,4 +58,4 @@ jobs:
|
|||||||
skip-cache: true
|
skip-cache: true
|
||||||
skip-save-cache: true
|
skip-save-cache: true
|
||||||
cache-invalidation-interval: 0
|
cache-invalidation-interval: 0
|
||||||
args: --timeout=12m
|
args: --timeout=20m
|
||||||
|
|||||||
66
.github/workflows/proto-version-check.yml
vendored
66
.github/workflows/proto-version-check.yml
vendored
@@ -20,34 +20,66 @@ jobs:
|
|||||||
per_page: 100,
|
per_page: 100,
|
||||||
});
|
});
|
||||||
|
|
||||||
const pbFiles = files.filter(f => f.filename.endsWith('.pb.go'));
|
const modifiedPbFiles = files.filter(
|
||||||
const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename);
|
f => f.filename.endsWith('.pb.go') && f.status === 'modified'
|
||||||
if (missingPatch.length > 0) {
|
);
|
||||||
core.setFailed(
|
if (modifiedPbFiles.length === 0) {
|
||||||
`Cannot inspect patch data for:\n` +
|
console.log('No modified .pb.go files to check');
|
||||||
missingPatch.map(f => `- ${f}`).join('\n') +
|
|
||||||
`\nThis can happen with very large PRs. Verify proto versions manually.`
|
|
||||||
);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
|
|
||||||
const violations = [];
|
|
||||||
|
|
||||||
for (const file of pbFiles) {
|
const versionPattern = /^\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
|
||||||
const changed = file.patch
|
const baseSha = context.payload.pull_request.base.sha;
|
||||||
.split('\n')
|
const headSha = context.payload.pull_request.head.sha;
|
||||||
.filter(line => versionPattern.test(line));
|
|
||||||
if (changed.length > 0) {
|
async function getVersionHeader(path, ref) {
|
||||||
|
try {
|
||||||
|
const res = await github.rest.repos.getContent({
|
||||||
|
owner: context.repo.owner,
|
||||||
|
repo: context.repo.repo,
|
||||||
|
path,
|
||||||
|
ref,
|
||||||
|
});
|
||||||
|
if (!res.data.content) {
|
||||||
|
return { ok: false, reason: 'no inline content (file too large)' };
|
||||||
|
}
|
||||||
|
const content = Buffer.from(res.data.content, 'base64').toString('utf8');
|
||||||
|
const lines = content
|
||||||
|
.split('\n')
|
||||||
|
.slice(0, 20)
|
||||||
|
.filter(line => versionPattern.test(line));
|
||||||
|
return { ok: true, lines };
|
||||||
|
} catch (e) {
|
||||||
|
return { ok: false, reason: e.message };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const violations = [];
|
||||||
|
for (const file of modifiedPbFiles) {
|
||||||
|
const [base, head] = await Promise.all([
|
||||||
|
getVersionHeader(file.filename, baseSha),
|
||||||
|
getVersionHeader(file.filename, headSha),
|
||||||
|
]);
|
||||||
|
if (!base.ok || !head.ok) {
|
||||||
|
core.warning(
|
||||||
|
`Skipping ${file.filename}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}`
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (base.lines.join('\n') !== head.lines.join('\n')) {
|
||||||
violations.push({
|
violations.push({
|
||||||
file: file.filename,
|
file: file.filename,
|
||||||
lines: changed,
|
base: base.lines,
|
||||||
|
head: head.lines,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (violations.length > 0) {
|
if (violations.length > 0) {
|
||||||
const details = violations.map(v =>
|
const details = violations.map(v =>
|
||||||
`${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}`
|
`${v.file}:\n` +
|
||||||
|
` base:\n${v.base.map(l => ' ' + l).join('\n') || ' (none)'}\n` +
|
||||||
|
` head:\n${v.head.map(l => ' ' + l).join('\n') || ' (none)'}`
|
||||||
).join('\n\n');
|
).join('\n\n');
|
||||||
|
|
||||||
core.setFailed(
|
core.setFailed(
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||||
|
|
||||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
|
|
||||||
@@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
iv, _ := validator.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
@@ -84,6 +85,12 @@ type Options struct {
|
|||||||
DisableIPv6 bool
|
DisableIPv6 bool
|
||||||
// BlockInbound blocks all inbound connections from peers
|
// BlockInbound blocks all inbound connections from peers
|
||||||
BlockInbound bool
|
BlockInbound bool
|
||||||
|
// BlockLANAccess blocks the embedded peer from reaching the host's
|
||||||
|
// LAN (RFC 1918, link-local, loopback) when it's used as a routing
|
||||||
|
// peer. Mirrors profilemanager.ConfigInput.BlockLANAccess. Useful
|
||||||
|
// when the embedded client must never act as a stepping stone into
|
||||||
|
// the host's local network (e.g. the proxy's overlay peer).
|
||||||
|
BlockLANAccess bool
|
||||||
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
|
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
|
||||||
WireguardPort *int
|
WireguardPort *int
|
||||||
// MTU is the MTU for the tunnel interface.
|
// MTU is the MTU for the tunnel interface.
|
||||||
@@ -94,6 +101,26 @@ type Options struct {
|
|||||||
MTU *uint16
|
MTU *uint16
|
||||||
// DNSLabels defines additional DNS labels configured in the peer.
|
// DNSLabels defines additional DNS labels configured in the peer.
|
||||||
DNSLabels []string
|
DNSLabels []string
|
||||||
|
// Performance configures the tunnel's buffer pool cap and batch size.
|
||||||
|
Performance Performance
|
||||||
|
}
|
||||||
|
|
||||||
|
// Performance configures the embedded client's tunnel memory/throughput knobs.
|
||||||
|
//
|
||||||
|
// These settings are process-global: any non-nil field also becomes the
|
||||||
|
// default for Clients constructed by later embed.New calls in the same
|
||||||
|
// process. Nil fields are ignored.
|
||||||
|
type Performance struct {
|
||||||
|
// PreallocatedBuffersPerPool caps the per-tunnel buffer pool. Zero
|
||||||
|
// leaves the pool unbounded. Lower values trade throughput for a
|
||||||
|
// tighter memory ceiling. May also be changed on a running Client via
|
||||||
|
// Client.SetPerformance, provided this field was nonzero at construction.
|
||||||
|
PreallocatedBuffersPerPool *uint32
|
||||||
|
// MaxBatchSize overrides the number of packets the tunnel reads or
|
||||||
|
// writes per syscall, which also bounds eager buffer allocation per
|
||||||
|
// worker. Zero uses the platform default. Applied at construction
|
||||||
|
// only; ignored by Client.SetPerformance.
|
||||||
|
MaxBatchSize *uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateCredentials checks that exactly one credential type is provided
|
// validateCredentials checks that exactly one credential type is provided
|
||||||
@@ -175,6 +202,7 @@ func New(opts Options) (*Client, error) {
|
|||||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||||
DisableIPv6: &opts.DisableIPv6,
|
DisableIPv6: &opts.DisableIPv6,
|
||||||
BlockInbound: &opts.BlockInbound,
|
BlockInbound: &opts.BlockInbound,
|
||||||
|
BlockLANAccess: &opts.BlockLANAccess,
|
||||||
WireguardPort: opts.WireguardPort,
|
WireguardPort: opts.WireguardPort,
|
||||||
MTU: opts.MTU,
|
MTU: opts.MTU,
|
||||||
DNSLabels: parsedLabels,
|
DNSLabels: parsedLabels,
|
||||||
@@ -192,6 +220,13 @@ func New(opts Options) (*Client, error) {
|
|||||||
config.PrivateKey = opts.PrivateKey
|
config.PrivateKey = opts.PrivateKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if opts.Performance.PreallocatedBuffersPerPool != nil {
|
||||||
|
wgdevice.SetPreallocatedBuffersPerPool(*opts.Performance.PreallocatedBuffersPerPool)
|
||||||
|
}
|
||||||
|
if opts.Performance.MaxBatchSize != nil {
|
||||||
|
wgdevice.SetMaxBatchSizeOverride(*opts.Performance.MaxBatchSize)
|
||||||
|
}
|
||||||
|
|
||||||
return &Client{
|
return &Client{
|
||||||
deviceName: opts.DeviceName,
|
deviceName: opts.DeviceName,
|
||||||
setupKey: opts.SetupKey,
|
setupKey: opts.SetupKey,
|
||||||
@@ -405,6 +440,21 @@ func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession,
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IdentityForIP looks up a remote peer by its tunnel IP using the
|
||||||
|
// embedded client's status recorder. Returns the peer's WireGuard public
|
||||||
|
// key and FQDN. ok=false means the IP isn't in this client's peer
|
||||||
|
// roster — callers should treat that as "unknown peer".
|
||||||
|
func (c *Client) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
|
||||||
|
if !ip.IsValid() || c.recorder == nil {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
state, found := c.recorder.PeerStateByIP(ip.String())
|
||||||
|
if !found {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
return state.PubKey, state.FQDN, true
|
||||||
|
}
|
||||||
|
|
||||||
// Status returns the current status of the client.
|
// Status returns the current status of the client.
|
||||||
func (c *Client) Status() (peer.FullStatus, error) {
|
func (c *Client) Status() (peer.FullStatus, error) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
@@ -473,6 +523,25 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
|
|||||||
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetPerformance retunes a running Client. Only PreallocatedBuffersPerPool
|
||||||
|
// takes effect, and only when it was nonzero at construction;
|
||||||
|
// MaxBatchSize is construction-only and returns an error if set here.
|
||||||
|
//
|
||||||
|
// Returns ErrClientNotStarted / ErrEngineNotStarted if the Client is not
|
||||||
|
// running yet.
|
||||||
|
func (c *Client) SetPerformance(t Performance) error {
|
||||||
|
if t.MaxBatchSize != nil {
|
||||||
|
return errors.New("MaxBatchSize is construction-only and cannot be changed at runtime")
|
||||||
|
}
|
||||||
|
engine, err := c.getEngine()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return engine.SetPerformance(internal.Performance{
|
||||||
|
PreallocatedBuffersPerPool: t.PreallocatedBuffersPerPool,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// StartCapture begins capturing packets on this client's tunnel device.
|
// StartCapture begins capturing packets on this client's tunnel device.
|
||||||
// Only one capture can be active at a time; starting a new one stops the previous.
|
// Only one capture can be active at a time; starting a new one stops the previous.
|
||||||
// Call StopCapture (or CaptureSession.Stop) to end it.
|
// Call StopCapture (or CaptureSession.Stop) to end it.
|
||||||
|
|||||||
@@ -1,199 +0,0 @@
|
|||||||
package iptables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
|
||||||
)
|
|
||||||
|
|
||||||
func iptRefcountIfaceV4() *iFaceMock {
|
|
||||||
return &iFaceMock{
|
|
||||||
NameFunc: func() string { return "wt-refcount" },
|
|
||||||
AddressFunc: func() wgaddr.Address {
|
|
||||||
return wgaddr.Address{
|
|
||||||
IP: netip.MustParseAddr("10.20.0.1"),
|
|
||||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func iptRefcountIfaceDual() *iFaceMock {
|
|
||||||
return &iFaceMock{
|
|
||||||
NameFunc: func() string { return "wt-refcount" },
|
|
||||||
AddressFunc: func() wgaddr.Address {
|
|
||||||
return wgaddr.Address{
|
|
||||||
IP: netip.MustParseAddr("10.20.0.1"),
|
|
||||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
|
||||||
IPv6: netip.MustParseAddr("fd00::1"),
|
|
||||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newIptRefcountManager(t *testing.T, dual bool) *Manager {
|
|
||||||
t.Helper()
|
|
||||||
var ifMock *iFaceMock
|
|
||||||
if dual {
|
|
||||||
ifMock = iptRefcountIfaceDual()
|
|
||||||
} else {
|
|
||||||
ifMock = iptRefcountIfaceV4()
|
|
||||||
}
|
|
||||||
m, err := Create(ifMock, iface.DefaultMTU)
|
|
||||||
require.NoError(t, err, "create manager")
|
|
||||||
require.NoError(t, m.Init(nil), "init manager")
|
|
||||||
t.Cleanup(func() {
|
|
||||||
require.NoError(t, m.Close(nil), "close manager")
|
|
||||||
})
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func iptDnatV4(port uint16) fw.ForwardRule {
|
|
||||||
return fw.ForwardRule{
|
|
||||||
Protocol: fw.ProtocolTCP,
|
|
||||||
DestinationPort: fw.Port{Values: []uint16{port}},
|
|
||||||
TranslatedAddress: netip.MustParseAddr("10.20.0.2"),
|
|
||||||
TranslatedPort: fw.Port{Values: []uint16{80}},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func iptDnatV6(port uint16) fw.ForwardRule {
|
|
||||||
return fw.ForwardRule{
|
|
||||||
Protocol: fw.ProtocolTCP,
|
|
||||||
DestinationPort: fw.Port{Values: []uint16{port}},
|
|
||||||
TranslatedAddress: netip.MustParseAddr("fd00::2"),
|
|
||||||
TranslatedPort: fw.Port{Values: []uint16{80}},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestIptablesDNAT_RefcountBalancedV4 covers a Balanced Add/Delete pair on v4.
|
|
||||||
func TestIptablesDNAT_RefcountBalancedV4(t *testing.T) {
|
|
||||||
m := newIptRefcountManager(t, false)
|
|
||||||
state := m.router.ipFwdState
|
|
||||||
|
|
||||||
r1, err := m.AddDNATRule(iptDnatV4(7081))
|
|
||||||
require.NoError(t, err, "add v4 dnat 1")
|
|
||||||
v4, v6 := state.Counts()
|
|
||||||
require.Equal(t, 1, v4, "v4 refcount after first add")
|
|
||||||
require.Equal(t, 0, v6, "v6 refcount unchanged")
|
|
||||||
|
|
||||||
r2, err := m.AddDNATRule(iptDnatV4(7082))
|
|
||||||
require.NoError(t, err, "add v4 dnat 2")
|
|
||||||
v4, v6 = state.Counts()
|
|
||||||
require.Equal(t, 2, v4, "v4 refcount after second add")
|
|
||||||
require.Equal(t, 0, v6, "v6 refcount unchanged")
|
|
||||||
|
|
||||||
require.NoError(t, m.DeleteDNATRule(r1))
|
|
||||||
v4, v6 = state.Counts()
|
|
||||||
require.Equal(t, 1, v4, "v4 refcount after first delete")
|
|
||||||
require.Equal(t, 0, v6, "v6 refcount unchanged")
|
|
||||||
|
|
||||||
require.NoError(t, m.DeleteDNATRule(r2))
|
|
||||||
v4, v6 = state.Counts()
|
|
||||||
require.Equal(t, 0, v4, "v4 refcount after second delete")
|
|
||||||
require.Equal(t, 0, v6, "v6 refcount unchanged")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestIptablesDNAT_RefcountBalancedV6 checks the v6 path increments v6 only and
|
|
||||||
// decrements back to zero.
|
|
||||||
func TestIptablesDNAT_RefcountBalancedV6(t *testing.T) {
|
|
||||||
m := newIptRefcountManager(t, true)
|
|
||||||
require.NotNil(t, m.router6, "v6 router")
|
|
||||||
require.Same(t, m.router.ipFwdState, m.router6.ipFwdState, "shared state")
|
|
||||||
state := m.router.ipFwdState
|
|
||||||
|
|
||||||
r1, err := m.AddDNATRule(iptDnatV6(9081))
|
|
||||||
require.NoError(t, err, "add v6 dnat 1")
|
|
||||||
v4, v6 := state.Counts()
|
|
||||||
require.Equal(t, 0, v4)
|
|
||||||
require.Equal(t, 1, v6, "v6 refcount after first add")
|
|
||||||
|
|
||||||
r2, err := m.AddDNATRule(iptDnatV6(9082))
|
|
||||||
require.NoError(t, err, "add v6 dnat 2")
|
|
||||||
v4, v6 = state.Counts()
|
|
||||||
require.Equal(t, 0, v4, "v4 refcount unchanged")
|
|
||||||
require.Equal(t, 2, v6, "v6 refcount after second add")
|
|
||||||
|
|
||||||
require.NoError(t, m.DeleteDNATRule(r1))
|
|
||||||
v4, v6 = state.Counts()
|
|
||||||
require.Equal(t, 0, v4, "v4 refcount unchanged")
|
|
||||||
require.Equal(t, 1, v6, "v6 refcount after first delete")
|
|
||||||
|
|
||||||
require.NoError(t, m.DeleteDNATRule(r2))
|
|
||||||
v4, v6 = state.Counts()
|
|
||||||
require.Equal(t, 0, v4)
|
|
||||||
require.Equal(t, 0, v6, "v6 refcount after second delete")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestIptablesDNAT_DuplicateAddNoLeak verifies the duplicate-rule path returns
|
|
||||||
// without bumping the refcount.
|
|
||||||
func TestIptablesDNAT_DuplicateAddNoLeak(t *testing.T) {
|
|
||||||
m := newIptRefcountManager(t, true)
|
|
||||||
state := m.router.ipFwdState
|
|
||||||
|
|
||||||
rule := iptDnatV4(7083)
|
|
||||||
r1, err := m.AddDNATRule(rule)
|
|
||||||
require.NoError(t, err)
|
|
||||||
v4, _ := state.Counts()
|
|
||||||
require.Equal(t, 1, v4)
|
|
||||||
|
|
||||||
_, err = m.AddDNATRule(rule)
|
|
||||||
require.NoError(t, err, "duplicate add")
|
|
||||||
v4, _ = state.Counts()
|
|
||||||
require.Equal(t, 1, v4, "duplicate add must not increment")
|
|
||||||
|
|
||||||
require.NoError(t, m.DeleteDNATRule(r1))
|
|
||||||
v4, _ = state.Counts()
|
|
||||||
require.Equal(t, 0, v4, "single delete must drop to zero")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestIptablesDNAT_DeleteMissingNoUnderflow verifies Delete on an unknown rule
|
|
||||||
// neither errors nor releases the refcount.
|
|
||||||
func TestIptablesDNAT_DeleteMissingNoUnderflow(t *testing.T) {
|
|
||||||
m := newIptRefcountManager(t, true)
|
|
||||||
state := m.router.ipFwdState
|
|
||||||
|
|
||||||
phantom := iptDnatV4(7099)
|
|
||||||
require.NoError(t, m.DeleteDNATRule(&phantom), "delete missing v4")
|
|
||||||
v4, v6 := state.Counts()
|
|
||||||
require.Equal(t, 0, v4)
|
|
||||||
require.Equal(t, 0, v6)
|
|
||||||
|
|
||||||
phantom6 := iptDnatV6(9099)
|
|
||||||
require.NoError(t, m.DeleteDNATRule(&phantom6), "delete missing v6")
|
|
||||||
v4, v6 = state.Counts()
|
|
||||||
require.Equal(t, 0, v4)
|
|
||||||
require.Equal(t, 0, v6)
|
|
||||||
|
|
||||||
r1, err := m.AddDNATRule(iptDnatV4(7100))
|
|
||||||
require.NoError(t, err)
|
|
||||||
v4, _ = state.Counts()
|
|
||||||
require.Equal(t, 1, v4, "real add still increments after phantom delete")
|
|
||||||
require.NoError(t, m.DeleteDNATRule(r1))
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestIptablesDNAT_DoubleDeleteNoUnderflow verifies a second Delete on the same
|
|
||||||
// rule is a no-op.
|
|
||||||
func TestIptablesDNAT_DoubleDeleteNoUnderflow(t *testing.T) {
|
|
||||||
m := newIptRefcountManager(t, true)
|
|
||||||
state := m.router.ipFwdState
|
|
||||||
|
|
||||||
r1, err := m.AddDNATRule(iptDnatV6(9083))
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, v6 := state.Counts()
|
|
||||||
require.Equal(t, 1, v6)
|
|
||||||
|
|
||||||
require.NoError(t, m.DeleteDNATRule(r1), "first delete")
|
|
||||||
_, v6 = state.Counts()
|
|
||||||
require.Equal(t, 0, v6)
|
|
||||||
|
|
||||||
require.NoError(t, m.DeleteDNATRule(r1), "second delete must be no-op")
|
|
||||||
_, v6 = state.Counts()
|
|
||||||
require.Equal(t, 0, v6, "double delete must not underflow")
|
|
||||||
}
|
|
||||||
@@ -89,7 +89,7 @@ func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Share the same IP forwarding state with the v4 router, since
|
// Share the same IP forwarding state with the v4 router, since
|
||||||
// Forwarding refcounter is per-family but shared between v4 and v6 routers.
|
// EnableIPForwarding controls both v4 and v6 sysctls.
|
||||||
m.router6.ipFwdState = m.router.ipFwdState
|
m.router6.ipFwdState = m.router.ipFwdState
|
||||||
|
|
||||||
m.aclMgr6, err = newAclManager(ip6Client, wgIface)
|
m.aclMgr6, err = newAclManager(ip6Client, wgIface)
|
||||||
@@ -402,33 +402,17 @@ func (m *Manager) SetLogLevel(log.Level) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) EnableRouting() error {
|
func (m *Manager) EnableRouting() error {
|
||||||
if err := m.router.ipFwdState.RequestForwarding(false); err != nil {
|
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||||
return fmt.Errorf("enable IPv4 forwarding: %w", err)
|
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||||
}
|
|
||||||
// v6 only when the overlay actually has v6.
|
|
||||||
if m.router6 == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err := m.router.ipFwdState.RequestForwarding(true); err != nil {
|
|
||||||
if rerr := m.router.ipFwdState.ReleaseForwarding(false); rerr != nil {
|
|
||||||
log.Warnf("rollback v4 forwarding: %v", rerr)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("enable IPv6 forwarding: %w", err)
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DisableRouting() error {
|
func (m *Manager) DisableRouting() error {
|
||||||
var merr *multierror.Error
|
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
if err := m.router.ipFwdState.ReleaseForwarding(false); err != nil {
|
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||||
merr = multierror.Append(merr, fmt.Errorf("disable IPv4 forwarding: %w", err))
|
|
||||||
}
|
}
|
||||||
if m.router6 != nil {
|
return nil
|
||||||
if err := m.router.ipFwdState.ReleaseForwarding(true); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("disable IPv6 forwarding: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddDNATRule adds a DNAT rule
|
// AddDNATRule adds a DNAT rule
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
|
|||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
||||||
ipFwdState: ipfwdstate.NewIPForwardingState(wgIface.Name()),
|
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||||
}
|
}
|
||||||
|
|
||||||
r.ipsetCounter = refcounter.New(
|
r.ipsetCounter = refcounter.New(
|
||||||
@@ -763,6 +763,10 @@ func (r *router) updateState() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
ruleKey := rule.ID()
|
ruleKey := rule.ID()
|
||||||
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||||
return rule, nil
|
return rule, nil
|
||||||
@@ -837,16 +841,6 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
|||||||
r.rules[key] = ruleInfo.rule
|
r.rules[key] = ruleInfo.rule
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.ipFwdState.RequestForwarding(r.v6); err != nil {
|
|
||||||
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
|
|
||||||
log.Errorf("rollback failed: %v", rollbackErr)
|
|
||||||
}
|
|
||||||
for key := range rules {
|
|
||||||
delete(r.rules, key)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("enable forwarding: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.updateState()
|
r.updateState()
|
||||||
return rule, nil
|
return rule, nil
|
||||||
}
|
}
|
||||||
@@ -867,15 +861,12 @@ func (r *router) rollbackRules(rules map[string]ruleInfo) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
ruleKey := rule.ID()
|
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
log.Errorf("%v", err)
|
||||||
_, hadDNAT := r.rules[ruleKey+dnatSuffix]
|
|
||||||
_, hadSNAT := r.rules[ruleKey+snatSuffix]
|
|
||||||
_, hadFWD := r.rules[ruleKey+fwdSuffix]
|
|
||||||
if !hadDNAT && !hadSNAT && !hadFWD {
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ruleKey := rule.ID()
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||||
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
|
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
|
||||||
@@ -898,10 +889,6 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
delete(r.rules, ruleKey+fwdSuffix)
|
delete(r.rules, ruleKey+fwdSuffix)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.ipFwdState.ReleaseForwarding(r.v6); err != nil {
|
|
||||||
log.Errorf("%v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.updateState()
|
r.updateState()
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,208 +0,0 @@
|
|||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
|
||||||
)
|
|
||||||
|
|
||||||
func nftRefcountIfaceV4() *iFaceMock {
|
|
||||||
return &iFaceMock{
|
|
||||||
NameFunc: func() string { return "wt-refcount" },
|
|
||||||
AddressFunc: func() wgaddr.Address {
|
|
||||||
return wgaddr.Address{
|
|
||||||
IP: netip.MustParseAddr("100.96.0.1"),
|
|
||||||
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func nftRefcountIfaceDual() *iFaceMock {
|
|
||||||
return &iFaceMock{
|
|
||||||
NameFunc: func() string { return "wt-refcount" },
|
|
||||||
AddressFunc: func() wgaddr.Address {
|
|
||||||
return wgaddr.Address{
|
|
||||||
IP: netip.MustParseAddr("100.96.0.1"),
|
|
||||||
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
|
||||||
IPv6: netip.MustParseAddr("fd00::1"),
|
|
||||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newNftRefcountManager(t *testing.T, dual bool) *Manager {
|
|
||||||
t.Helper()
|
|
||||||
if check() != NFTABLES {
|
|
||||||
t.Skip("nftables not supported on this system")
|
|
||||||
}
|
|
||||||
var ifMock *iFaceMock
|
|
||||||
if dual {
|
|
||||||
ifMock = nftRefcountIfaceDual()
|
|
||||||
} else {
|
|
||||||
ifMock = nftRefcountIfaceV4()
|
|
||||||
}
|
|
||||||
m, err := Create(ifMock, iface.DefaultMTU)
|
|
||||||
require.NoError(t, err, "create manager")
|
|
||||||
require.NoError(t, m.Init(nil), "init manager")
|
|
||||||
t.Cleanup(func() {
|
|
||||||
require.NoError(t, m.Close(nil), "close manager")
|
|
||||||
})
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func dnatV4(port uint16) fw.ForwardRule {
|
|
||||||
return fw.ForwardRule{
|
|
||||||
Protocol: fw.ProtocolTCP,
|
|
||||||
DestinationPort: fw.Port{Values: []uint16{port}},
|
|
||||||
TranslatedAddress: netip.MustParseAddr("100.96.0.2"),
|
|
||||||
TranslatedPort: fw.Port{Values: []uint16{80}},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func dnatV6(port uint16) fw.ForwardRule {
|
|
||||||
return fw.ForwardRule{
|
|
||||||
Protocol: fw.ProtocolTCP,
|
|
||||||
DestinationPort: fw.Port{Values: []uint16{port}},
|
|
||||||
TranslatedAddress: netip.MustParseAddr("fd00::2"),
|
|
||||||
TranslatedPort: fw.Port{Values: []uint16{80}},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestNftablesDNAT_RefcountBalancedV4 verifies that Add/Delete pairs leave the
|
|
||||||
// v4 refcount at zero.
|
|
||||||
func TestNftablesDNAT_RefcountBalancedV4(t *testing.T) {
|
|
||||||
m := newNftRefcountManager(t, false)
|
|
||||||
state := m.router.ipFwdState
|
|
||||||
|
|
||||||
r1, err := m.AddDNATRule(dnatV4(8081))
|
|
||||||
require.NoError(t, err, "add v4 dnat 1")
|
|
||||||
v4, v6 := state.Counts()
|
|
||||||
require.Equal(t, 1, v4, "v4 refcount after first add")
|
|
||||||
require.Equal(t, 0, v6, "v6 refcount unchanged")
|
|
||||||
|
|
||||||
r2, err := m.AddDNATRule(dnatV4(8082))
|
|
||||||
require.NoError(t, err, "add v4 dnat 2")
|
|
||||||
v4, v6 = state.Counts()
|
|
||||||
require.Equal(t, 2, v4, "v4 refcount after second add")
|
|
||||||
require.Equal(t, 0, v6, "v6 refcount unchanged")
|
|
||||||
|
|
||||||
require.NoError(t, m.DeleteDNATRule(r1), "delete v4 dnat 1")
|
|
||||||
v4, v6 = state.Counts()
|
|
||||||
require.Equal(t, 1, v4, "v4 refcount after first delete")
|
|
||||||
require.Equal(t, 0, v6, "v6 refcount unchanged")
|
|
||||||
|
|
||||||
require.NoError(t, m.DeleteDNATRule(r2), "delete v4 dnat 2")
|
|
||||||
v4, v6 = state.Counts()
|
|
||||||
require.Equal(t, 0, v4, "v4 refcount after second delete")
|
|
||||||
require.Equal(t, 0, v6, "v6 refcount unchanged")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestNftablesDNAT_RefcountBalancedV6 verifies the v6 path increments v6 only
|
|
||||||
// and decrements back to zero on Delete.
|
|
||||||
func TestNftablesDNAT_RefcountBalancedV6(t *testing.T) {
|
|
||||||
m := newNftRefcountManager(t, true)
|
|
||||||
require.NotNil(t, m.router6, "v6 router")
|
|
||||||
require.Same(t, m.router.ipFwdState, m.router6.ipFwdState, "shared state")
|
|
||||||
state := m.router.ipFwdState
|
|
||||||
|
|
||||||
r1, err := m.AddDNATRule(dnatV6(9091))
|
|
||||||
require.NoError(t, err, "add v6 dnat 1")
|
|
||||||
v4, v6 := state.Counts()
|
|
||||||
require.Equal(t, 0, v4, "v4 refcount unchanged")
|
|
||||||
require.Equal(t, 1, v6, "v6 refcount after first add")
|
|
||||||
|
|
||||||
r2, err := m.AddDNATRule(dnatV6(9092))
|
|
||||||
require.NoError(t, err, "add v6 dnat 2")
|
|
||||||
v4, v6 = state.Counts()
|
|
||||||
require.Equal(t, 0, v4)
|
|
||||||
require.Equal(t, 2, v6, "v6 refcount after second add")
|
|
||||||
|
|
||||||
require.NoError(t, m.DeleteDNATRule(r1), "delete v6 dnat 1")
|
|
||||||
v4, v6 = state.Counts()
|
|
||||||
require.Equal(t, 0, v4, "v4 refcount unchanged")
|
|
||||||
require.Equal(t, 1, v6, "v6 refcount after first delete")
|
|
||||||
|
|
||||||
require.NoError(t, m.DeleteDNATRule(r2), "delete v6 dnat 2")
|
|
||||||
v4, v6 = state.Counts()
|
|
||||||
require.Equal(t, 0, v4)
|
|
||||||
require.Equal(t, 0, v6, "v6 refcount after second delete")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestNftablesDNAT_DuplicateAddNoLeak verifies that a duplicate Add (same
|
|
||||||
// ForwardRule) does not double-increment the refcount.
|
|
||||||
func TestNftablesDNAT_DuplicateAddNoLeak(t *testing.T) {
|
|
||||||
m := newNftRefcountManager(t, true)
|
|
||||||
state := m.router.ipFwdState
|
|
||||||
|
|
||||||
rule := dnatV4(8083)
|
|
||||||
r1, err := m.AddDNATRule(rule)
|
|
||||||
require.NoError(t, err, "add v4 dnat")
|
|
||||||
v4, _ := state.Counts()
|
|
||||||
require.Equal(t, 1, v4)
|
|
||||||
|
|
||||||
// duplicate add: same rule ID, must be a no-op for the refcount.
|
|
||||||
_, err = m.AddDNATRule(rule)
|
|
||||||
require.NoError(t, err, "duplicate add")
|
|
||||||
v4, _ = state.Counts()
|
|
||||||
require.Equal(t, 1, v4, "duplicate add must not increment")
|
|
||||||
|
|
||||||
require.NoError(t, m.DeleteDNATRule(r1), "delete v4 dnat")
|
|
||||||
v4, _ = state.Counts()
|
|
||||||
require.Equal(t, 0, v4, "single delete must drop to zero")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestNftablesDNAT_DeleteMissingNoUnderflow verifies deleting a rule that was
|
|
||||||
// never added does not underflow the refcount.
|
|
||||||
func TestNftablesDNAT_DeleteMissingNoUnderflow(t *testing.T) {
|
|
||||||
m := newNftRefcountManager(t, true)
|
|
||||||
state := m.router.ipFwdState
|
|
||||||
|
|
||||||
// Construct a Rule reference for something never added. The router stores
|
|
||||||
// rules by ID(), and DeleteDNATRule looks them up in r.rules; a missing
|
|
||||||
// entry must be a no-op rather than calling Release.
|
|
||||||
phantom := dnatV4(8099)
|
|
||||||
require.NoError(t, m.DeleteDNATRule(&phantom), "delete missing v4 dnat")
|
|
||||||
v4, v6 := state.Counts()
|
|
||||||
require.Equal(t, 0, v4, "v4 refcount unaffected by missing delete")
|
|
||||||
require.Equal(t, 0, v6, "v6 refcount unaffected")
|
|
||||||
|
|
||||||
phantom6 := dnatV6(9099)
|
|
||||||
require.NoError(t, m.DeleteDNATRule(&phantom6), "delete missing v6 dnat")
|
|
||||||
v4, v6 = state.Counts()
|
|
||||||
require.Equal(t, 0, v4)
|
|
||||||
require.Equal(t, 0, v6, "v6 refcount unaffected by missing delete")
|
|
||||||
|
|
||||||
// And after a phantom delete, a real add still results in count=1.
|
|
||||||
r1, err := m.AddDNATRule(dnatV4(8100))
|
|
||||||
require.NoError(t, err, "add v4 dnat after phantom delete")
|
|
||||||
v4, _ = state.Counts()
|
|
||||||
require.Equal(t, 1, v4, "real add still increments after phantom delete")
|
|
||||||
require.NoError(t, m.DeleteDNATRule(r1))
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestNftablesDNAT_DoubleDeleteNoUnderflow verifies that deleting the same rule
|
|
||||||
// twice does not underflow the refcount (the second delete is a no-op).
|
|
||||||
func TestNftablesDNAT_DoubleDeleteNoUnderflow(t *testing.T) {
|
|
||||||
m := newNftRefcountManager(t, true)
|
|
||||||
state := m.router.ipFwdState
|
|
||||||
|
|
||||||
r1, err := m.AddDNATRule(dnatV6(9093))
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, v6 := state.Counts()
|
|
||||||
require.Equal(t, 1, v6)
|
|
||||||
|
|
||||||
require.NoError(t, m.DeleteDNATRule(r1), "first delete")
|
|
||||||
_, v6 = state.Counts()
|
|
||||||
require.Equal(t, 0, v6)
|
|
||||||
|
|
||||||
require.NoError(t, m.DeleteDNATRule(r1), "second delete must be no-op")
|
|
||||||
_, v6 = state.Counts()
|
|
||||||
require.Equal(t, 0, v6, "double delete must not underflow")
|
|
||||||
}
|
|
||||||
@@ -105,8 +105,8 @@ func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mt
|
|||||||
return fmt.Errorf("create v6 router: %w", err)
|
return fmt.Errorf("create v6 router: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Share the per-family forwarding refcounter with the v4 router so a v4
|
// Share the same IP forwarding state with the v4 router, since
|
||||||
// rule and a v6 rule against the same state machine cooperate cleanly.
|
// EnableIPForwarding controls both v4 and v6 sysctls.
|
||||||
m.router6.ipFwdState = m.router.ipFwdState
|
m.router6.ipFwdState = m.router.ipFwdState
|
||||||
|
|
||||||
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
|
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
|
||||||
@@ -530,33 +530,17 @@ func (m *Manager) SetLogLevel(log.Level) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) EnableRouting() error {
|
func (m *Manager) EnableRouting() error {
|
||||||
if err := m.router.ipFwdState.RequestForwarding(false); err != nil {
|
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||||
return fmt.Errorf("enable IPv4 forwarding: %w", err)
|
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||||
}
|
|
||||||
// v6 only when the overlay actually has v6.
|
|
||||||
if m.router6 == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err := m.router.ipFwdState.RequestForwarding(true); err != nil {
|
|
||||||
if rerr := m.router.ipFwdState.ReleaseForwarding(false); rerr != nil {
|
|
||||||
log.Warnf("rollback v4 forwarding: %v", rerr)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("enable IPv6 forwarding: %w", err)
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DisableRouting() error {
|
func (m *Manager) DisableRouting() error {
|
||||||
var merr *multierror.Error
|
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
if err := m.router.ipFwdState.ReleaseForwarding(false); err != nil {
|
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||||
merr = multierror.Append(merr, fmt.Errorf("disable IPv4 forwarding: %w", err))
|
|
||||||
}
|
}
|
||||||
if m.router6 != nil {
|
return nil
|
||||||
if err := m.router.ipFwdState.ReleaseForwarding(true); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("disable IPv6 forwarding: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush rule/chain/set operations from the buffer
|
// Flush rule/chain/set operations from the buffer
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
|
|||||||
rules: make(map[string]*nftables.Rule),
|
rules: make(map[string]*nftables.Rule),
|
||||||
af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4),
|
af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4),
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
ipFwdState: ipfwdstate.NewIPForwardingState(wgIface.Name()),
|
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1550,6 +1550,10 @@ func (r *router) refreshRulesMap() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
ruleKey := rule.ID()
|
ruleKey := rule.ID()
|
||||||
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||||
return rule, nil
|
return rule, nil
|
||||||
@@ -1560,18 +1564,7 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
|||||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Request forwarding before queueing rules: addDnatRedirect/addDnatMasq
|
|
||||||
// buffer netlink messages on r.conn that the next caller's Flush would
|
|
||||||
// commit if we returned without flushing them ourselves.
|
|
||||||
v6 := r.af.tableFamily == nftables.TableFamilyIPv6
|
|
||||||
if err := r.ipFwdState.RequestForwarding(v6); err != nil {
|
|
||||||
return nil, fmt.Errorf("enable forwarding: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.addDnatRedirect(rule, protoNum, ruleKey); err != nil {
|
if err := r.addDnatRedirect(rule, protoNum, ruleKey); err != nil {
|
||||||
if rerr := r.ipFwdState.ReleaseForwarding(v6); rerr != nil {
|
|
||||||
log.Warnf("rollback forwarding refcount: %v", rerr)
|
|
||||||
}
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1583,11 +1576,6 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
|||||||
// TODO: find chains with drop policies and add rules there
|
// TODO: find chains with drop policies and add rules there
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
if rerr := r.ipFwdState.ReleaseForwarding(v6); rerr != nil {
|
|
||||||
log.Warnf("rollback forwarding refcount: %v", rerr)
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleKey+dnatSuffix)
|
|
||||||
delete(r.rules, ruleKey+snatSuffix)
|
|
||||||
return nil, fmt.Errorf("flush rules: %w", err)
|
return nil, fmt.Errorf("flush rules: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1790,18 +1778,16 @@ func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
log.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
ruleKey := rule.ID()
|
ruleKey := rule.ID()
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, hadDNAT := r.rules[ruleKey+dnatSuffix]
|
|
||||||
_, hadSNAT := r.rules[ruleKey+snatSuffix]
|
|
||||||
if !hadDNAT && !hadSNAT {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
var needsFlush bool
|
var needsFlush bool
|
||||||
|
|
||||||
@@ -1838,10 +1824,6 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
delete(r.rules, ruleKey+snatSuffix)
|
delete(r.rules, ruleKey+snatSuffix)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.ipFwdState.ReleaseForwarding(r.af.tableFamily == nftables.TableFamilyIPv6); err != nil {
|
|
||||||
log.Errorf("%v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -844,10 +844,6 @@ func collectSysctls() string {
|
|||||||
[]string{"net.ipv4.conf.all.src_valid_mark", "net.ipv4.conf.default.src_valid_mark"},
|
[]string{"net.ipv4.conf.all.src_valid_mark", "net.ipv4.conf.default.src_valid_mark"},
|
||||||
listInterfaceSysctls("ipv4", "src_valid_mark")...,
|
listInterfaceSysctls("ipv4", "src_valid_mark")...,
|
||||||
))
|
))
|
||||||
writeSysctlGroup(&builder, "accept_ra", append(
|
|
||||||
[]string{"net.ipv6.conf.all.accept_ra", "net.ipv6.conf.default.accept_ra"},
|
|
||||||
listInterfaceSysctls("ipv6", "accept_ra")...,
|
|
||||||
))
|
|
||||||
writeSysctlGroup(&builder, "conntrack", []string{
|
writeSysctlGroup(&builder, "conntrack", []string{
|
||||||
"net.netfilter.nf_conntrack_acct",
|
"net.netfilter.nf_conntrack_acct",
|
||||||
"net.netfilter.nf_conntrack_tcp_loose",
|
"net.netfilter.nf_conntrack_tcp_loose",
|
||||||
|
|||||||
@@ -339,8 +339,7 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
|||||||
case entry.Pattern == ".":
|
case entry.Pattern == ".":
|
||||||
return true
|
return true
|
||||||
case entry.IsWildcard:
|
case entry.IsWildcard:
|
||||||
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
|
return strings.HasSuffix(qname, "."+entry.Pattern)
|
||||||
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
|
|
||||||
default:
|
default:
|
||||||
// For non-wildcard patterns:
|
// For non-wildcard patterns:
|
||||||
// If handler wants subdomain matching, allow suffix match
|
// If handler wants subdomain matching, allow suffix match
|
||||||
|
|||||||
@@ -164,6 +164,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
|||||||
matchSubdomains: true,
|
matchSubdomains: true,
|
||||||
shouldMatch: true,
|
shouldMatch: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard label-boundary mismatch (suffix overlap)",
|
||||||
|
handlerDomain: "*.b.test.",
|
||||||
|
queryDomain: "x.ab.test.",
|
||||||
|
isWildcard: true,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard label-boundary match",
|
||||||
|
handlerDomain: "*.b.test.",
|
||||||
|
queryDomain: "x.b.test.",
|
||||||
|
isWildcard: true,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard multi-label match",
|
||||||
|
handlerDomain: "*.b.test.",
|
||||||
|
queryDomain: "x.y.b.test.",
|
||||||
|
isWildcard: true,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard no match on multi-label apex",
|
||||||
|
handlerDomain: "*.b.test.",
|
||||||
|
queryDomain: "b.test.",
|
||||||
|
isWildcard: true,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard no match on unrelated suffix containment",
|
||||||
|
handlerDomain: "*.example.com.",
|
||||||
|
queryDomain: "notexample.com.",
|
||||||
|
isWildcard: true,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard accepts pattern registered without trailing dot",
|
||||||
|
handlerDomain: "*.b.test",
|
||||||
|
queryDomain: "x.b.test.",
|
||||||
|
isWildcard: true,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -273,6 +321,19 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
expectedCalls: 1,
|
expectedCalls: 1,
|
||||||
expectedHandler: 2, // highest priority matching handler should be called
|
expectedHandler: 2, // highest priority matching handler should be called
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "overlapping wildcard suffixes route to correct handler",
|
||||||
|
handlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
}{
|
||||||
|
{pattern: "*.b.test.", priority: nbdns.PriorityDNSRoute},
|
||||||
|
{pattern: "*.ab.test.", priority: nbdns.PriorityDNSRoute},
|
||||||
|
},
|
||||||
|
queryDomain: "app.ab.test.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
expectedHandler: 1,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "root zone with specific domain",
|
name: "root zone with specific domain",
|
||||||
handlers: []struct {
|
handlers: []struct {
|
||||||
|
|||||||
@@ -26,6 +26,19 @@ type resolver interface {
|
|||||||
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PeerConnectivity reports whether a tunnel IP belongs to a peer the
|
||||||
|
// client knows about and whether that peer is currently connected. The
|
||||||
|
// local resolver uses this to suppress A/AAAA answers whose RDATA points
|
||||||
|
// at a disconnected peer (typical case: a synthesized private-service
|
||||||
|
// record pointing at an embedded proxy peer that just went offline).
|
||||||
|
//
|
||||||
|
// known=false means the IP isn't in the local peerstore at all — the
|
||||||
|
// record is left alone (it points at something outside our mesh, e.g.
|
||||||
|
// a non-peer upstream).
|
||||||
|
type PeerConnectivity interface {
|
||||||
|
IsConnectedByIP(ip string) (known, connected bool)
|
||||||
|
}
|
||||||
|
|
||||||
type Resolver struct {
|
type Resolver struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
records map[dns.Question][]dns.RR
|
records map[dns.Question][]dns.RR
|
||||||
@@ -33,6 +46,11 @@ type Resolver struct {
|
|||||||
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
|
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
|
||||||
zones map[domain.Domain]bool
|
zones map[domain.Domain]bool
|
||||||
resolver resolver
|
resolver resolver
|
||||||
|
// peerConn, when non-nil, is consulted on every A/AAAA answer to
|
||||||
|
// drop records pointing at disconnected peers. nil disables the
|
||||||
|
// filter and preserves the legacy "return whatever is registered"
|
||||||
|
// behaviour for callers that never wire a status source.
|
||||||
|
peerConn PeerConnectivity
|
||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
@@ -49,6 +67,15 @@ func NewResolver() *Resolver {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetPeerConnectivity wires the per-IP connectivity check used to filter
|
||||||
|
// out A/AAAA answers pointing at disconnected peers. Pass nil to disable.
|
||||||
|
// Safe to call multiple times; the latest value wins.
|
||||||
|
func (d *Resolver) SetPeerConnectivity(p PeerConnectivity) {
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
d.peerConn = p
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Resolver) MatchSubdomains() bool {
|
func (d *Resolver) MatchSubdomains() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -95,6 +122,7 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
replyMessage.RecursionAvailable = true
|
replyMessage.RecursionAvailable = true
|
||||||
|
|
||||||
result := d.lookupRecords(logger, question)
|
result := d.lookupRecords(logger, question)
|
||||||
|
result.records = d.filterDisconnectedPeerAnswers(logger, question, result.records)
|
||||||
replyMessage.Authoritative = !result.hasExternalData
|
replyMessage.Authoritative = !result.hasExternalData
|
||||||
replyMessage.Answer = result.records
|
replyMessage.Answer = result.records
|
||||||
replyMessage.Rcode = d.determineRcode(question, result)
|
replyMessage.Rcode = d.determineRcode(question, result)
|
||||||
@@ -436,6 +464,78 @@ func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// filterDisconnectedPeerAnswers drops A/AAAA records whose RDATA matches
|
||||||
|
// a known but disconnected peer. The synthesized private-service zones
|
||||||
|
// emit one A record per connected proxy peer in a cluster; when a peer
|
||||||
|
// goes offline, the server-side refresh removes the record from the
|
||||||
|
// next netmap, but the client may still hold the previous netmap for a
|
||||||
|
// short window. This filter is the local belt to that braces — even on
|
||||||
|
// the stale netmap, the resolver hides the offline target.
|
||||||
|
//
|
||||||
|
// Records pointing at unknown IPs (outside the local peerstore, e.g.
|
||||||
|
// non-mesh upstreams) are never dropped. Non-A/AAAA records pass
|
||||||
|
// through untouched.
|
||||||
|
//
|
||||||
|
// Escape hatch: if filtering would leave the answer empty AND at least
|
||||||
|
// one record was filtered, the original list is returned. Better to
|
||||||
|
// hand the client a record that may not respond than NXDOMAIN it
|
||||||
|
// completely when every proxy peer is offline (the upstream may still
|
||||||
|
// be reachable some other way, or the peerstore may be stale).
|
||||||
|
func (d *Resolver) filterDisconnectedPeerAnswers(logger *log.Entry, question dns.Question, records []dns.RR) []dns.RR {
|
||||||
|
if len(records) == 0 {
|
||||||
|
return records
|
||||||
|
}
|
||||||
|
d.mu.RLock()
|
||||||
|
checker := d.peerConn
|
||||||
|
d.mu.RUnlock()
|
||||||
|
if checker == nil {
|
||||||
|
return records
|
||||||
|
}
|
||||||
|
|
||||||
|
kept := make([]dns.RR, 0, len(records))
|
||||||
|
var dropped int
|
||||||
|
for _, rr := range records {
|
||||||
|
ip := extractRecordIP(rr)
|
||||||
|
if ip == "" {
|
||||||
|
kept = append(kept, rr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
known, connected := checker.IsConnectedByIP(ip)
|
||||||
|
if known && !connected {
|
||||||
|
dropped++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
kept = append(kept, rr)
|
||||||
|
}
|
||||||
|
if dropped == 0 {
|
||||||
|
return records
|
||||||
|
}
|
||||||
|
if len(kept) == 0 {
|
||||||
|
logger.Debugf("all %d answers for %s point at disconnected peers; returning the original list", dropped, question.Name)
|
||||||
|
return records
|
||||||
|
}
|
||||||
|
logger.Tracef("dropped %d disconnected-peer answer(s) for %s, returning %d", dropped, question.Name, len(kept))
|
||||||
|
return kept
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractRecordIP returns the dotted-decimal / colon-hex IP carried by
|
||||||
|
// an A or AAAA record, or "" for any other record type.
|
||||||
|
func extractRecordIP(rr dns.RR) string {
|
||||||
|
switch r := rr.(type) {
|
||||||
|
case *dns.A:
|
||||||
|
if r.A == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return r.A.String()
|
||||||
|
case *dns.AAAA:
|
||||||
|
if r.AAAA == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return r.AAAA.String()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// Update replaces all zones and their records
|
// Update replaces all zones and their records
|
||||||
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
|
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
|
||||||
d.mu.Lock()
|
d.mu.Lock()
|
||||||
|
|||||||
@@ -30,6 +30,21 @@ func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mockPeerConnectivity returns canned (known, connected) results per IP.
|
||||||
|
// Used by the disconnected-peer filter tests below. IPs not in the map
|
||||||
|
// are reported as unknown so the filter leaves them alone.
|
||||||
|
type mockPeerConnectivity struct {
|
||||||
|
byIP map[string]struct{ known, connected bool }
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) {
|
||||||
|
v, ok := m.byIP[ip]
|
||||||
|
if !ok {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
return v.known, v.connected
|
||||||
|
}
|
||||||
|
|
||||||
func TestLocalResolver_ServeDNS(t *testing.T) {
|
func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||||
recordA := nbdns.SimpleRecord{
|
recordA := nbdns.SimpleRecord{
|
||||||
Name: "peera.netbird.cloud.",
|
Name: "peera.netbird.cloud.",
|
||||||
@@ -2652,3 +2667,114 @@ func BenchmarkIsInManagedZone_ManyZones(b *testing.B) {
|
|||||||
resolver.isInManagedZone(qname)
|
resolver.isInManagedZone(qname)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestLocalResolver_FilterDisconnectedPeerAnswers verifies the
|
||||||
|
// connectivity-aware filtering layered on top of lookupRecords:
|
||||||
|
// when an A record's IP belongs to a known peer that's disconnected,
|
||||||
|
// the record is dropped from the answer. Records for unknown IPs pass
|
||||||
|
// through. If filtering would empty the answer entirely and at least
|
||||||
|
// one record was dropped, the original list is restored (escape hatch
|
||||||
|
// for the "all proxies offline" case).
|
||||||
|
func TestLocalResolver_FilterDisconnectedPeerAnswers(t *testing.T) {
|
||||||
|
zone := "svc.cluster.netbird."
|
||||||
|
connectedRec := nbdns.SimpleRecord{
|
||||||
|
Name: zone,
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 5,
|
||||||
|
RData: "100.64.0.10",
|
||||||
|
}
|
||||||
|
disconnectedRec := nbdns.SimpleRecord{
|
||||||
|
Name: zone,
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 5,
|
||||||
|
RData: "100.64.0.11",
|
||||||
|
}
|
||||||
|
unknownRec := nbdns.SimpleRecord{
|
||||||
|
Name: zone,
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 5,
|
||||||
|
RData: "203.0.113.5",
|
||||||
|
}
|
||||||
|
|
||||||
|
type ipState struct{ known, connected bool }
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
records []nbdns.SimpleRecord
|
||||||
|
connByIP map[string]ipState
|
||||||
|
wantInOrder []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "drops disconnected peer, keeps connected",
|
||||||
|
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
|
||||||
|
connByIP: map[string]ipState{
|
||||||
|
"100.64.0.10": {known: true, connected: true},
|
||||||
|
"100.64.0.11": {known: true, connected: false},
|
||||||
|
},
|
||||||
|
wantInOrder: []string{"100.64.0.10"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown IPs pass through untouched",
|
||||||
|
records: []nbdns.SimpleRecord{unknownRec, disconnectedRec},
|
||||||
|
connByIP: map[string]ipState{
|
||||||
|
"100.64.0.11": {known: true, connected: false},
|
||||||
|
},
|
||||||
|
wantInOrder: []string{"203.0.113.5"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all disconnected falls back to original list",
|
||||||
|
records: []nbdns.SimpleRecord{disconnectedRec, connectedRec},
|
||||||
|
connByIP: map[string]ipState{
|
||||||
|
"100.64.0.10": {known: true, connected: false},
|
||||||
|
"100.64.0.11": {known: true, connected: false},
|
||||||
|
},
|
||||||
|
wantInOrder: []string{"100.64.0.11", "100.64.0.10"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no checker wired returns all records",
|
||||||
|
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
|
||||||
|
connByIP: nil,
|
||||||
|
wantInOrder: []string{"100.64.0.10", "100.64.0.11"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
if tc.connByIP != nil {
|
||||||
|
cm := mockPeerConnectivity{byIP: make(map[string]struct{ known, connected bool }, len(tc.connByIP))}
|
||||||
|
for ip, st := range tc.connByIP {
|
||||||
|
cm.byIP[ip] = struct{ known, connected bool }{st.known, st.connected}
|
||||||
|
}
|
||||||
|
resolver.SetPeerConnectivity(cm)
|
||||||
|
}
|
||||||
|
resolver.Update([]nbdns.CustomZone{{
|
||||||
|
Domain: strings.TrimSuffix(zone, "."),
|
||||||
|
Records: tc.records,
|
||||||
|
NonAuthoritative: true,
|
||||||
|
}})
|
||||||
|
|
||||||
|
var got *dns.Msg
|
||||||
|
writer := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
got = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
req := new(dns.Msg).SetQuestion(zone, dns.TypeA)
|
||||||
|
resolver.ServeDNS(writer, req)
|
||||||
|
|
||||||
|
require.NotNil(t, got, "resolver must produce a response")
|
||||||
|
require.Len(t, got.Answer, len(tc.wantInOrder),
|
||||||
|
"answer count must match expected: %v", tc.wantInOrder)
|
||||||
|
for i, want := range tc.wantInOrder {
|
||||||
|
a, ok := got.Answer[i].(*dns.A)
|
||||||
|
require.True(t, ok, "answer[%d] must be an A record", i)
|
||||||
|
assert.Equal(t, want, a.A.String(),
|
||||||
|
"answer[%d] expected %s got %s", i, want, a.A.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -301,6 +301,11 @@ func newDefaultServer(
|
|||||||
warningDelayBase: defaultWarningDelayBase,
|
warningDelayBase: defaultWarningDelayBase,
|
||||||
healthRefresh: make(chan struct{}, 1),
|
healthRefresh: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
|
// Wire the local resolver against the peer status recorder so it can
|
||||||
|
// suppress A/AAAA answers that point at disconnected peers (typical
|
||||||
|
// case: synthesised private-service records pointing at an embedded
|
||||||
|
// proxy peer that just went offline).
|
||||||
|
defaultServer.localResolver.SetPeerConnectivity(localPeerConnectivity{statusRecorder})
|
||||||
|
|
||||||
// register with root zone, handler chain takes care of the routing
|
// register with root zone, handler chain takes care of the routing
|
||||||
dnsService.RegisterMux(".", handlerChain)
|
dnsService.RegisterMux(".", handlerChain)
|
||||||
@@ -1386,3 +1391,25 @@ func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// localPeerConnectivity adapts *peer.Status to local.PeerConnectivity so
|
||||||
|
// the local resolver can ask "is this IP a known peer and is it
|
||||||
|
// connected?" without taking on the peer package as a dependency.
|
||||||
|
// A nil status recorder always reports known=false so the resolver
|
||||||
|
// short-circuits to the legacy "return everything" path.
|
||||||
|
type localPeerConnectivity struct {
|
||||||
|
status *peer.Status
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsConnectedByIP looks the IP up in the peerstore and surfaces both
|
||||||
|
// the known and connected bits. Used by Resolver.filterDisconnectedPeerAnswers.
|
||||||
|
func (l localPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) {
|
||||||
|
if l.status == nil {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
state, ok := l.status.PeerStateByIP(ip)
|
||||||
|
if !ok {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
return true, state.ConnStatus == peer.StatusConnected
|
||||||
|
}
|
||||||
|
|||||||
@@ -61,9 +61,11 @@ import (
|
|||||||
cProto "github.com/netbirdio/netbird/client/proto"
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
types "github.com/netbirdio/netbird/shared/management/types"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
|
nbnetworkmap "github.com/netbirdio/netbird/shared/management/networkmap"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/shared/netiputil"
|
"github.com/netbirdio/netbird/shared/netiputil"
|
||||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||||
@@ -202,6 +204,13 @@ type Engine struct {
|
|||||||
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
||||||
networkSerial uint64
|
networkSerial uint64
|
||||||
|
|
||||||
|
// latestComponents is the most-recent NetworkMapComponents decoded from
|
||||||
|
// a NetworkMapEnvelope (capability=3 peers only). Held alongside the
|
||||||
|
// NetworkMap that Calculate() produced from it so future incremental
|
||||||
|
// updates have a base to apply changes against. nil for legacy-format
|
||||||
|
// peers. Guarded by syncMsgMux.
|
||||||
|
latestComponents *types.NetworkMapComponents
|
||||||
|
|
||||||
networkMonitor *networkmonitor.NetworkMonitor
|
networkMonitor *networkmonitor.NetworkMonitor
|
||||||
|
|
||||||
sshServer sshServer
|
sshServer sshServer
|
||||||
@@ -865,8 +874,12 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
return e.ctx.Err()
|
return e.ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
// Envelope sync responses carry PeerConfig at the top level; legacy
|
||||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
|
// NetworkMap syncs carry it under NetworkMap.PeerConfig.
|
||||||
|
if pc := update.GetPeerConfig(); pc != nil {
|
||||||
|
e.handleAutoUpdateVersion(pc.GetAutoUpdate())
|
||||||
|
} else if nm := update.GetNetworkMap(); nm != nil && nm.GetPeerConfig() != nil {
|
||||||
|
e.handleAutoUpdateVersion(nm.GetPeerConfig().GetAutoUpdate())
|
||||||
}
|
}
|
||||||
|
|
||||||
if update.GetNetbirdConfig() != nil {
|
if update.GetNetbirdConfig() != nil {
|
||||||
@@ -907,11 +920,45 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
nm := update.GetNetworkMap()
|
var (
|
||||||
|
nm *mgmProto.NetworkMap
|
||||||
|
components *types.NetworkMapComponents
|
||||||
|
)
|
||||||
|
if envelope := update.GetNetworkMapEnvelope(); envelope != nil {
|
||||||
|
// Components-format peer: decode the envelope back to typed
|
||||||
|
// components, run Calculate() locally, and convert to the wire
|
||||||
|
// NetworkMap shape the rest of the engine consumes. Components are
|
||||||
|
// retained so future incremental updates can apply deltas instead
|
||||||
|
// of doing a full reconstruction.
|
||||||
|
localKey := e.config.WgPrivateKey.PublicKey().String()
|
||||||
|
dnsName := ""
|
||||||
|
if pc := update.GetPeerConfig(); pc != nil {
|
||||||
|
// PeerConfig.Fqdn = "<dns_label>.<dns_domain>" — extract the
|
||||||
|
// shared domain by stripping the peer's own label prefix. Falls
|
||||||
|
// back to empty if the FQDN doesn't have the expected shape.
|
||||||
|
dnsName = extractDNSDomainFromFQDN(pc.GetFqdn())
|
||||||
|
}
|
||||||
|
result, err := nbnetworkmap.EnvelopeToNetworkMap(e.ctx, envelope, localKey, dnsName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("decode network map envelope: %w", err)
|
||||||
|
}
|
||||||
|
nm = result.NetworkMap
|
||||||
|
components = result.Components
|
||||||
|
} else {
|
||||||
|
nm = update.GetNetworkMap()
|
||||||
|
}
|
||||||
if nm == nil {
|
if nm == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Only retain the components view when the server sent the envelope
|
||||||
|
// path. A legacy proto.NetworkMap means components == nil; writing it
|
||||||
|
// here would clobber a previously-cached snapshot, breaking the
|
||||||
|
// incremental-delta base on a future envelope sync.
|
||||||
|
if components != nil {
|
||||||
|
e.latestComponents = components
|
||||||
|
}
|
||||||
|
|
||||||
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
|
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
|
||||||
// Read the storage-enabled flag under the syncRespMux too.
|
// Read the storage-enabled flag under the syncRespMux too.
|
||||||
e.syncRespMux.RLock()
|
e.syncRespMux.RLock()
|
||||||
@@ -937,6 +984,19 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractDNSDomainFromFQDN returns the trailing dotted domain part of the
|
||||||
|
// receiving peer's FQDN — the same value the management server fills as
|
||||||
|
// dnsName when it builds the legacy NetworkMap. "peer42.netbird.cloud" →
|
||||||
|
// "netbird.cloud". An empty string is returned for unrecognized formats.
|
||||||
|
func extractDNSDomainFromFQDN(fqdn string) string {
|
||||||
|
for i := 0; i < len(fqdn); i++ {
|
||||||
|
if fqdn[i] == '.' && i+1 < len(fqdn) {
|
||||||
|
return fqdn[i+1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
||||||
if update != nil {
|
if update != nil {
|
||||||
// when we receive token we expect valid address list too
|
// when we receive token we expect valid address list too
|
||||||
@@ -1967,6 +2027,29 @@ func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
|
|||||||
return e.clientMetrics
|
return e.clientMetrics
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Performance bundles runtime-adjustable tunnel pool knobs.
|
||||||
|
// See Engine.SetPerformance. Nil fields are ignored.
|
||||||
|
type Performance struct {
|
||||||
|
PreallocatedBuffersPerPool *uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPerformance applies the given tuning to this engine's live Device.
|
||||||
|
func (e *Engine) SetPerformance(t Performance) error {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
defer e.syncMsgMux.Unlock()
|
||||||
|
if e.wgInterface == nil {
|
||||||
|
return fmt.Errorf("wg interface not initialized")
|
||||||
|
}
|
||||||
|
dev := e.wgInterface.GetWGDevice()
|
||||||
|
if dev == nil {
|
||||||
|
return fmt.Errorf("wg device not initialized")
|
||||||
|
}
|
||||||
|
if t.PreallocatedBuffersPerPool != nil {
|
||||||
|
dev.SetPreallocatedBuffersPerPool(*t.PreallocatedBuffersPerPool)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||||
iface, err := net.InterfaceByName(ifaceName)
|
iface, err := net.InterfaceByName(ifaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/management/server/job"
|
"github.com/netbirdio/netbird/management/server/job"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
@@ -66,8 +66,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
|
||||||
"github.com/netbirdio/netbird/shared/netiputil"
|
"github.com/netbirdio/netbird/shared/netiputil"
|
||||||
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||||
@@ -1641,7 +1641,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
|
|||||||
switch msg.Type {
|
switch msg.Type {
|
||||||
// handle route changes
|
// handle route changes
|
||||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||||
route, err := parseRouteMessage(buf[:n])
|
route, flags, err := parseRouteMessage(buf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||||
continue
|
continue
|
||||||
@@ -66,6 +66,10 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
|
|||||||
}
|
}
|
||||||
switch msg.Type {
|
switch msg.Type {
|
||||||
case unix.RTM_ADD:
|
case unix.RTM_ADD:
|
||||||
|
if systemops.IgnoreAddedDefaultRoute(flags) {
|
||||||
|
log.Debugf("Network monitor: ignoring added default route via %s, interface %s, flags %#x", route.Gw, intf, flags)
|
||||||
|
continue
|
||||||
|
}
|
||||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||||
return nil
|
return nil
|
||||||
case unix.RTM_DELETE:
|
case unix.RTM_DELETE:
|
||||||
@@ -78,22 +82,26 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
func parseRouteMessage(buf []byte) (*systemops.Route, int, error) {
|
||||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("parse RIB: %v", err)
|
return nil, 0, fmt.Errorf("parse RIB: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(msgs) != 1 {
|
if len(msgs) != 1 {
|
||||||
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
return nil, 0, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||||
}
|
}
|
||||||
|
|
||||||
msg, ok := msgs[0].(*route.RouteMessage)
|
msg, ok := msgs[0].(*route.RouteMessage)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
return nil, 0, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
return systemops.MsgToRoute(msg)
|
r, err := systemops.MsgToRoute(msg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
return r, msg.Flags, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// waitReadable blocks until fd has data to read, or ctx is cancelled.
|
// waitReadable blocks until fd has data to read, or ctx is cancelled.
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer/id"
|
"github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/worker"
|
"github.com/netbirdio/netbird/client/internal/peer/worker"
|
||||||
"github.com/netbirdio/netbird/client/internal/portforward"
|
"github.com/netbirdio/netbird/client/internal/portforward"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
@@ -899,7 +900,7 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to deterministic key if no NetBird PSK is configured
|
// Fallback to deterministic key if no NetBird PSK is configured
|
||||||
determKey, err := conn.rosenpassDetermKey()
|
determKey, err := rosenpass.DeterministicSeedKey(conn.config.LocalKey, conn.config.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
|
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
|
||||||
return nil
|
return nil
|
||||||
@@ -908,26 +909,6 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
|||||||
return determKey
|
return determKey
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo: move this logic into Rosenpass package
|
|
||||||
func (conn *Conn) rosenpassDetermKey() (*wgtypes.Key, error) {
|
|
||||||
lk := []byte(conn.config.LocalKey)
|
|
||||||
rk := []byte(conn.config.Key) // remote key
|
|
||||||
var keyInput []byte
|
|
||||||
if string(lk) > string(rk) {
|
|
||||||
//nolint:gocritic
|
|
||||||
keyInput = append(lk[:16], rk[:16]...)
|
|
||||||
} else {
|
|
||||||
//nolint:gocritic
|
|
||||||
keyInput = append(rk[:16], lk[:16]...)
|
|
||||||
}
|
|
||||||
|
|
||||||
key, err := wgtypes.NewKey(keyInput)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isController(config ConnConfig) bool {
|
func isController(config ConnConfig) bool {
|
||||||
return config.LocalKey > config.Key
|
return config.LocalKey > config.Key
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -185,9 +185,12 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
|
|||||||
return s.eventsChan
|
return s.eventsChan
|
||||||
}
|
}
|
||||||
|
|
||||||
// Status holds a state of peers, signal, management connections and relays
|
// Status holds a state of peers, signal, management connections and relays.
|
||||||
|
// mux is an RWMutex so hot read paths (notably PeerStateByIP, called for
|
||||||
|
// every private-service request) don't contend against each other.
|
||||||
|
// Pure read methods take RLock; anything that mutates state takes Lock.
|
||||||
type Status struct {
|
type Status struct {
|
||||||
mux sync.Mutex
|
mux sync.RWMutex
|
||||||
peers map[string]State
|
peers map[string]State
|
||||||
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
||||||
signalState bool
|
signalState bool
|
||||||
@@ -283,8 +286,8 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string)
|
|||||||
|
|
||||||
// GetPeer adds peer to Daemon status map
|
// GetPeer adds peer to Daemon status map
|
||||||
func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
|
|
||||||
state, ok := d.peers[peerPubKey]
|
state, ok := d.peers[peerPubKey]
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -294,8 +297,8 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) PeerByIP(ip string) (string, bool) {
|
func (d *Status) PeerByIP(ip string) (string, bool) {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
|
|
||||||
for _, state := range d.peers {
|
for _, state := range d.peers {
|
||||||
if state.IP == ip {
|
if state.IP == ip {
|
||||||
@@ -305,6 +308,25 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
|
|||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PeerStateByIP returns the full peer State for the given tunnel IP.
|
||||||
|
// Matches against either the IPv4 (State.IP) or IPv6 (State.IPv6) tunnel
|
||||||
|
// address so dual-stack peers are reachable on either family. Returns the
|
||||||
|
// zero State and false when no peer matches or the input is empty.
|
||||||
|
func (d *Status) PeerStateByIP(ip string) (State, bool) {
|
||||||
|
if ip == "" {
|
||||||
|
return State{}, false
|
||||||
|
}
|
||||||
|
d.mux.RLock()
|
||||||
|
defer d.mux.RUnlock()
|
||||||
|
|
||||||
|
for _, state := range d.peers {
|
||||||
|
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
|
||||||
|
return state, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return State{}, false
|
||||||
|
}
|
||||||
|
|
||||||
// RemovePeer removes peer from Daemon status map
|
// RemovePeer removes peer from Daemon status map
|
||||||
func (d *Status) RemovePeer(peerPubKey string) error {
|
func (d *Status) RemovePeer(peerPubKey string) error {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
@@ -702,8 +724,8 @@ func (d *Status) UnsubscribePeerStateChanges(subscription *StatusChangeSubscript
|
|||||||
|
|
||||||
// GetLocalPeerState returns the local peer state
|
// GetLocalPeerState returns the local peer state
|
||||||
func (d *Status) GetLocalPeerState() LocalPeerState {
|
func (d *Status) GetLocalPeerState() LocalPeerState {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
return d.localPeer.Clone()
|
return d.localPeer.Clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -909,8 +931,8 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetRosenpassState() RosenpassState {
|
func (d *Status) GetRosenpassState() RosenpassState {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
return RosenpassState{
|
return RosenpassState{
|
||||||
d.rosenpassEnabled,
|
d.rosenpassEnabled,
|
||||||
d.rosenpassPermissive,
|
d.rosenpassPermissive,
|
||||||
@@ -918,14 +940,14 @@ func (d *Status) GetRosenpassState() RosenpassState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetLazyConnection() bool {
|
func (d *Status) GetLazyConnection() bool {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
return d.lazyConnectionEnabled
|
return d.lazyConnectionEnabled
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetManagementState() ManagementState {
|
func (d *Status) GetManagementState() ManagementState {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
return ManagementState{
|
return ManagementState{
|
||||||
d.mgmAddress,
|
d.mgmAddress,
|
||||||
d.managementState,
|
d.managementState,
|
||||||
@@ -951,8 +973,8 @@ func (d *Status) UpdateLatency(pubKey string, latency time.Duration) error {
|
|||||||
|
|
||||||
// IsLoginRequired determines if a peer's login has expired.
|
// IsLoginRequired determines if a peer's login has expired.
|
||||||
func (d *Status) IsLoginRequired() bool {
|
func (d *Status) IsLoginRequired() bool {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
|
|
||||||
// if peer is connected to the management then login is not expired
|
// if peer is connected to the management then login is not expired
|
||||||
if d.managementState {
|
if d.managementState {
|
||||||
@@ -967,8 +989,8 @@ func (d *Status) IsLoginRequired() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetSignalState() SignalState {
|
func (d *Status) GetSignalState() SignalState {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
return SignalState{
|
return SignalState{
|
||||||
d.signalAddress,
|
d.signalAddress,
|
||||||
d.signalState,
|
d.signalState,
|
||||||
@@ -978,8 +1000,8 @@ func (d *Status) GetSignalState() SignalState {
|
|||||||
|
|
||||||
// GetRelayStates returns the stun/turn/permanent relay states
|
// GetRelayStates returns the stun/turn/permanent relay states
|
||||||
func (d *Status) GetRelayStates() []relay.ProbeResult {
|
func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
if d.relayMgr == nil {
|
if d.relayMgr == nil {
|
||||||
return d.relayStates
|
return d.relayStates
|
||||||
}
|
}
|
||||||
@@ -1008,8 +1030,8 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) ForwardingRules() []firewall.ForwardRule {
|
func (d *Status) ForwardingRules() []firewall.ForwardRule {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
if d.ingressGwMgr == nil {
|
if d.ingressGwMgr == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1018,16 +1040,16 @@ func (d *Status) ForwardingRules() []firewall.ForwardRule {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetDNSStates() []NSGroupState {
|
func (d *Status) GetDNSStates() []NSGroupState {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
|
|
||||||
// shallow copy is good enough, as slices fields are currently not updated
|
// shallow copy is good enough, as slices fields are currently not updated
|
||||||
return slices.Clone(d.nsGroupStates)
|
return slices.Clone(d.nsGroupStates)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
|
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
return maps.Clone(d.resolvedDomainsStates)
|
return maps.Clone(d.resolvedDomainsStates)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1043,8 +1065,8 @@ func (d *Status) GetFullStatus() FullStatus {
|
|||||||
LazyConnectionEnabled: d.GetLazyConnection(),
|
LazyConnectionEnabled: d.GetLazyConnection(),
|
||||||
}
|
}
|
||||||
|
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
|
|
||||||
fullStatus.LocalPeerState = d.localPeer
|
fullStatus.LocalPeerState = d.localPeer
|
||||||
|
|
||||||
@@ -1219,8 +1241,8 @@ func (d *Status) SetWgIface(wgInterface WGIfaceStatus) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) PeersStatus() (*configurer.Stats, error) {
|
func (d *Status) PeersStatus() (*configurer.Stats, error) {
|
||||||
d.mux.Lock()
|
d.mux.RLock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.RUnlock()
|
||||||
if d.wgIface == nil {
|
if d.wgIface == nil {
|
||||||
return nil, fmt.Errorf("wgInterface is nil, cannot retrieve peers status")
|
return nil, fmt.Errorf("wgInterface is nil, cannot retrieve peers status")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,6 +63,33 @@ func TestUpdatePeerState(t *testing.T) {
|
|||||||
assert.Equal(t, ip, state.IP, "ip should be equal")
|
assert.Equal(t, ip, state.IP, "ip should be equal")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStatus_PeerStateByIP(t *testing.T) {
|
||||||
|
status := NewRecorder("https://mgm")
|
||||||
|
req := require.New(t)
|
||||||
|
|
||||||
|
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", ""))
|
||||||
|
req.NoError(status.AddPeer("pk-2", "peer-2.netbird", "100.64.0.11", ""))
|
||||||
|
|
||||||
|
state, ok := status.PeerStateByIP("100.64.0.10")
|
||||||
|
req.True(ok, "known tunnel IP should resolve to a peer state")
|
||||||
|
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
|
||||||
|
req.Equal("peer-1.netbird", state.FQDN, "matching state must carry the right FQDN")
|
||||||
|
|
||||||
|
_, ok = status.PeerStateByIP("100.64.0.99")
|
||||||
|
req.False(ok, "unknown IP must report ok=false")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStatus_PeerStateByIP_MatchesIPv6(t *testing.T) {
|
||||||
|
status := NewRecorder("https://mgm")
|
||||||
|
req := require.New(t)
|
||||||
|
|
||||||
|
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", "fd00::1"))
|
||||||
|
|
||||||
|
state, ok := status.PeerStateByIP("fd00::1")
|
||||||
|
req.True(ok, "IPv6-only match must resolve to the peer state")
|
||||||
|
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
|
||||||
|
}
|
||||||
|
|
||||||
func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
||||||
key := "abc"
|
key := "abc"
|
||||||
fqdn := "peer-a.netbird.local"
|
fqdn := "peer-a.netbird.local"
|
||||||
|
|||||||
@@ -28,6 +28,15 @@ func hashRosenpassKey(key []byte) string {
|
|||||||
return hex.EncodeToString(hasher.Sum(nil))
|
return hex.EncodeToString(hasher.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// rpServer is the subset of rp.Server used by Manager. Defined as an interface
|
||||||
|
// so tests can substitute a mock without spinning up a real UDP server.
|
||||||
|
type rpServer interface {
|
||||||
|
AddPeer(rp.PeerConfig) (rp.PeerID, error)
|
||||||
|
RemovePeer(rp.PeerID) error
|
||||||
|
Run() error
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
ifaceName string
|
ifaceName string
|
||||||
spk []byte
|
spk []byte
|
||||||
@@ -36,7 +45,7 @@ type Manager struct {
|
|||||||
preSharedKey *[32]byte
|
preSharedKey *[32]byte
|
||||||
rpPeerIDs map[string]*rp.PeerID
|
rpPeerIDs map[string]*rp.PeerID
|
||||||
rpWgHandler *NetbirdHandler
|
rpWgHandler *NetbirdHandler
|
||||||
server *rp.Server
|
server rpServer
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
port int
|
port int
|
||||||
wgIface PresharedKeySetter
|
wgIface PresharedKeySetter
|
||||||
@@ -51,7 +60,22 @@ func NewManager(preSharedKey *wgtypes.Key, wgIfaceName string) (*Manager, error)
|
|||||||
|
|
||||||
rpKeyHash := hashRosenpassKey(public)
|
rpKeyHash := hashRosenpassKey(public)
|
||||||
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
|
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
|
||||||
return &Manager{ifaceName: wgIfaceName, rpKeyHash: rpKeyHash, spk: public, ssk: secret, preSharedKey: (*[32]byte)(preSharedKey), rpPeerIDs: make(map[string]*rp.PeerID), lock: sync.Mutex{}}, nil
|
return &Manager{
|
||||||
|
ifaceName: wgIfaceName,
|
||||||
|
rpKeyHash: rpKeyHash,
|
||||||
|
spk: public,
|
||||||
|
ssk: secret,
|
||||||
|
preSharedKey: (*[32]byte)(preSharedKey),
|
||||||
|
rpPeerIDs: make(map[string]*rp.PeerID),
|
||||||
|
// rpWgHandler is created here (instead of only in generateConfig) so it
|
||||||
|
// is never nil between NewManager and Run(). Otherwise an early
|
||||||
|
// OnConnected call (race observed on Android, issue #4341) panics on
|
||||||
|
// nil receiver in addPeer -> m.rpWgHandler.AddPeer. generateConfig will
|
||||||
|
// replace it with a fresh handler on each Run() to clear stale peer
|
||||||
|
// state from previous engine sessions.
|
||||||
|
rpWgHandler: NewNetbirdHandler(),
|
||||||
|
lock: sync.Mutex{},
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) GetPubKey() []byte {
|
func (m *Manager) GetPubKey() []byte {
|
||||||
@@ -65,6 +89,16 @@ func (m *Manager) GetAddress() *net.UDPAddr {
|
|||||||
|
|
||||||
// addPeer adds a new peer to the Rosenpass server
|
// addPeer adds a new peer to the Rosenpass server
|
||||||
func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuardIP string, wireGuardPubKey string) error {
|
func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuardIP string, wireGuardPubKey string) error {
|
||||||
|
// Defense in depth against issue #4341 (Android crash): if Run() has not
|
||||||
|
// completed yet, m.server / m.rpWgHandler may be nil. Return an explicit
|
||||||
|
// error instead of panicking on nil-receiver dereference.
|
||||||
|
if m.server == nil {
|
||||||
|
return fmt.Errorf("rosenpass server not initialized")
|
||||||
|
}
|
||||||
|
if m.rpWgHandler == nil {
|
||||||
|
return fmt.Errorf("rosenpass wg handler not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
pcfg := rp.PeerConfig{PublicKey: rosenpassPubKey}
|
pcfg := rp.PeerConfig{PublicKey: rosenpassPubKey}
|
||||||
if m.preSharedKey != nil {
|
if m.preSharedKey != nil {
|
||||||
@@ -79,6 +113,16 @@ func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuar
|
|||||||
if pcfg.Endpoint, err = net.ResolveUDPAddr("udp", peerAddr); err != nil {
|
if pcfg.Endpoint, err = net.ResolveUDPAddr("udp", peerAddr); err != nil {
|
||||||
return fmt.Errorf("failed to resolve peer endpoint address: %w", err)
|
return fmt.Errorf("failed to resolve peer endpoint address: %w", err)
|
||||||
}
|
}
|
||||||
|
// Our local Rosenpass UDP server binds on the IPv6 wildcard ([::]) — see
|
||||||
|
// GetAddress(). The remote peer's endpoint (pcfg.Endpoint) is the destination
|
||||||
|
// our server will sendto when initiating handshakes. ResolveUDPAddr returns a
|
||||||
|
// 4-byte IPv4 for IPv4 hosts, which the kernel rejects (EDESTADDRREQ) when
|
||||||
|
// sent from an AF_INET6 socket. Normalize the remote endpoint to IPv4-mapped
|
||||||
|
// IPv6 so its address family matches our listening socket.
|
||||||
|
// TODO: maybe bind the Rosenpass UDP server to the peer wg IP addr
|
||||||
|
if v4 := pcfg.Endpoint.IP.To4(); v4 != nil {
|
||||||
|
pcfg.Endpoint.IP = v4.To16()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
peerID, err := m.server.AddPeer(pcfg)
|
peerID, err := m.server.AddPeer(pcfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -182,24 +226,31 @@ func (m *Manager) Run() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.server, err = rp.NewUDPServer(conf)
|
server, err := rp.NewUDPServer(conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.lock.Lock()
|
||||||
|
m.server = server
|
||||||
|
m.lock.Unlock()
|
||||||
|
|
||||||
log.Infof("starting rosenpass server on port %d", m.port)
|
log.Infof("starting rosenpass server on port %d", m.port)
|
||||||
|
|
||||||
return m.server.Run()
|
return server.Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the Rosenpass server
|
// Close closes the Rosenpass server
|
||||||
func (m *Manager) Close() error {
|
func (m *Manager) Close() error {
|
||||||
if m.server != nil {
|
m.lock.Lock()
|
||||||
err := m.server.Close()
|
server := m.server
|
||||||
if err != nil {
|
m.server = nil
|
||||||
log.Errorf("failed closing local rosenpass server")
|
m.lock.Unlock()
|
||||||
}
|
if server == nil {
|
||||||
m.server = nil
|
return nil
|
||||||
|
}
|
||||||
|
if err := server.Close(); err != nil {
|
||||||
|
log.Errorf("failed closing local rosenpass server: %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,412 @@
|
|||||||
package rosenpass
|
package rosenpass
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
rp "cunicu.li/go-rosenpass"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// --- test doubles -----------------------------------------------------------
|
||||||
|
|
||||||
|
type addPeerCall struct {
|
||||||
|
cfg rp.PeerConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
type removePeerCall struct {
|
||||||
|
id rp.PeerID
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockServer struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
addCalls []addPeerCall
|
||||||
|
removed []removePeerCall
|
||||||
|
nextID rp.PeerID
|
||||||
|
addErr error
|
||||||
|
removeErr error
|
||||||
|
closed bool
|
||||||
|
ran bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockServer) AddPeer(cfg rp.PeerConfig) (rp.PeerID, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.addCalls = append(m.addCalls, addPeerCall{cfg: cfg})
|
||||||
|
if m.addErr != nil {
|
||||||
|
return rp.PeerID{}, m.addErr
|
||||||
|
}
|
||||||
|
// Increment a byte in nextID so distinct peers get distinct IDs.
|
||||||
|
m.nextID[0]++
|
||||||
|
return m.nextID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockServer) RemovePeer(id rp.PeerID) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.removed = append(m.removed, removePeerCall{id: id})
|
||||||
|
return m.removeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockServer) Run() error { m.ran = true; return nil }
|
||||||
|
func (m *mockServer) Close() error { m.closed = true; return nil }
|
||||||
|
|
||||||
|
type setPSKCall struct {
|
||||||
|
peerKey string
|
||||||
|
psk wgtypes.Key
|
||||||
|
updateOnly bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockIface struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
calls []setPSKCall
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.calls = append(m.calls, setPSKCall{peerKey: peerKey, psk: psk, updateOnly: updateOnly})
|
||||||
|
return m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTestManager builds a Manager with deterministic spk so tie-break
|
||||||
|
// against a peer pubkey is controllable from tests. The provided spk byte
|
||||||
|
// becomes the first byte; remaining bytes are zero.
|
||||||
|
func newTestManager(spkFirstByte byte, mock *mockServer) *Manager {
|
||||||
|
spk := make([]byte, 32)
|
||||||
|
spk[0] = spkFirstByte
|
||||||
|
return &Manager{
|
||||||
|
ifaceName: "wt0",
|
||||||
|
spk: spk,
|
||||||
|
ssk: make([]byte, 32),
|
||||||
|
rpKeyHash: "test-hash",
|
||||||
|
rpPeerIDs: make(map[string]*rp.PeerID),
|
||||||
|
rpWgHandler: NewNetbirdHandler(),
|
||||||
|
server: mock,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// validWGKey returns a deterministic 32-byte wireguard public key (base64).
|
||||||
|
func validWGKey(t *testing.T, lastByte byte) string {
|
||||||
|
t.Helper()
|
||||||
|
var k wgtypes.Key
|
||||||
|
k[31] = lastByte
|
||||||
|
return k.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- pure helpers ----------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHashRosenpassKey_Deterministic(t *testing.T) {
|
||||||
|
key := []byte("hello-rosenpass")
|
||||||
|
require.Equal(t, hashRosenpassKey(key), hashRosenpassKey(key))
|
||||||
|
require.Len(t, hashRosenpassKey(key), 64) // sha256 hex
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashRosenpassKey_DifferentInputsDifferOutputs(t *testing.T) {
|
||||||
|
require.NotEqual(t, hashRosenpassKey([]byte("a")), hashRosenpassKey([]byte("b")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetLogLevel_DefaultWhenUnset(t *testing.T) {
|
||||||
|
// Snapshot + unset to exercise the LookupEnv ok=false branch. t.Setenv
|
||||||
|
// can only set, not delete, so do it manually with restore via t.Cleanup.
|
||||||
|
prev, hadPrev := os.LookupEnv(defaultLogLevelVar)
|
||||||
|
require.NoError(t, os.Unsetenv(defaultLogLevelVar))
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if hadPrev {
|
||||||
|
_ = os.Setenv(defaultLogLevelVar, prev)
|
||||||
|
} else {
|
||||||
|
_ = os.Unsetenv(defaultLogLevelVar)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
require.Equal(t, defaultLog.String(), getLogLevel().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetLogLevel_Cases(t *testing.T) {
|
||||||
|
cases := map[string]string{
|
||||||
|
"debug": "DEBUG",
|
||||||
|
"info": "INFO",
|
||||||
|
"warn": "WARN",
|
||||||
|
"error": "ERROR",
|
||||||
|
"unknown": "INFO", // default fallback
|
||||||
|
}
|
||||||
|
for input, wantStr := range cases {
|
||||||
|
input, wantStr := input, wantStr
|
||||||
|
t.Run(input, func(t *testing.T) {
|
||||||
|
t.Setenv(defaultLogLevelVar, input)
|
||||||
|
require.Equal(t, wantStr, getLogLevel().String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFindRandomAvailableUDPPort(t *testing.T) {
|
func TestFindRandomAvailableUDPPort(t *testing.T) {
|
||||||
port, err := findRandomAvailableUDPPort()
|
port, err := findRandomAvailableUDPPort()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Greater(t, port, 0)
|
require.Greater(t, port, 0)
|
||||||
require.LessOrEqual(t, port, 65535)
|
require.LessOrEqual(t, port, 65535)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- addPeer ---------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestAddPeer_HigherLocalPubkey_SetsEndpoint(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv) // local spk lexicographically larger
|
||||||
|
|
||||||
|
remotePubKey := make([]byte, 32) // remote spk = all zeros (smaller)
|
||||||
|
err := m.addPeer(remotePubKey, "rosenpass-host:7000", "100.1.1.1", validWGKey(t, 1))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, srv.addCalls, 1)
|
||||||
|
|
||||||
|
ep := srv.addCalls[0].cfg.Endpoint
|
||||||
|
require.NotNil(t, ep, "initiator side must set Endpoint")
|
||||||
|
require.Equal(t, 7000, ep.Port)
|
||||||
|
require.Equal(t, "100.1.1.1", ep.IP.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_HigherLocalPubkey_EndpointIPIsIPv4Mapped(t *testing.T) {
|
||||||
|
// Regression guard for the EDESTADDRREQ fix: Endpoint.IP must be 16-byte
|
||||||
|
// (IPv4-mapped IPv6) so it matches the AF_INET6 listening socket family.
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ep := srv.addCalls[0].cfg.Endpoint
|
||||||
|
require.NotNil(t, ep)
|
||||||
|
require.Len(t, ep.IP, 16, "IPv4 endpoint must be normalized to 16-byte v4-mapped form")
|
||||||
|
require.True(t, ep.IP.To4() != nil, "Endpoint must still be detected as IPv4")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_LowerLocalPubkey_LeavesEndpointNil(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0x00, srv) // local spk smaller
|
||||||
|
|
||||||
|
remotePubKey := make([]byte, 32)
|
||||||
|
remotePubKey[0] = 0xFF
|
||||||
|
err := m.addPeer(remotePubKey, "rp:5000", "100.1.1.1", validWGKey(t, 2))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Nil(t, srv.addCalls[0].cfg.Endpoint, "responder side must NOT set Endpoint")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_PresharedKeyPropagated(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
psk := &wgtypes.Key{0x42}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
m.preSharedKey = (*[32]byte)(psk)
|
||||||
|
|
||||||
|
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 3))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, [32]byte(*psk), [32]byte(srv.addCalls[0].cfg.PresharedKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_InvalidRosenpassAddr_ReturnsError(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv) // initiator path → parses rosenpassAddr
|
||||||
|
|
||||||
|
err := m.addPeer(make([]byte, 32), "not-a-host-port", "100.1.1.1", validWGKey(t, 1))
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Empty(t, srv.addCalls, "server.AddPeer must not run when address parse fails")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_InvalidWireGuardPubKey_ReturnsError(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", "not-a-valid-key")
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_ServerError_Propagates(t *testing.T) {
|
||||||
|
srv := &mockServer{addErr: errors.New("boom")}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regression guard for issue #4341 (Android crash). If Run() has not completed
|
||||||
|
// before OnConnected fires, m.rpWgHandler or m.server may be nil. Without the
|
||||||
|
// nil guards, m.rpWgHandler.AddPeer panics on nil receiver.
|
||||||
|
func TestAddPeer_NilHandler_ReturnsErrorNoCrash(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
m.rpWgHandler = nil // simulate Run() not yet completed
|
||||||
|
|
||||||
|
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "wg handler not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_NilServer_ReturnsErrorNoCrash(t *testing.T) {
|
||||||
|
m := newTestManager(0xFF, nil)
|
||||||
|
m.server = nil // simulate Run() not yet completed
|
||||||
|
|
||||||
|
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "server not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager must pre-initialize rpWgHandler so the nil-receiver crash from
|
||||||
|
// issue #4341 cannot occur in the window between NewManager and Run().
|
||||||
|
func TestNewManager_PreInitializesHandler(t *testing.T) {
|
||||||
|
psk := wgtypes.Key{}
|
||||||
|
m, err := NewManager(&psk, "wt0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, m.rpWgHandler, "rpWgHandler must be initialized in NewManager")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPeer_RecordsPeerID(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
wgKey := validWGKey(t, 5)
|
||||||
|
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Contains(t, m.rpPeerIDs, wgKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- OnConnected / OnDisconnected ------------------------------------------
|
||||||
|
|
||||||
|
func TestOnConnected_NilRemotePubKey_NoAddPeer(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
m.OnConnected(validWGKey(t, 1), nil, "100.1.1.1", "rp:5000")
|
||||||
|
require.Empty(t, srv.addCalls, "nil remote rosenpass pubkey must skip AddPeer")
|
||||||
|
require.Empty(t, m.rpPeerIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOnConnected_ValidPubKey_CallsAddPeer(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
wgKey := validWGKey(t, 1)
|
||||||
|
m.OnConnected(wgKey, make([]byte, 32), "100.1.1.1", "rp:5000")
|
||||||
|
require.Len(t, srv.addCalls, 1)
|
||||||
|
require.Contains(t, m.rpPeerIDs, wgKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOnDisconnected_UnknownPeer_NoOp(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
m.OnDisconnected(validWGKey(t, 99))
|
||||||
|
require.Empty(t, srv.removed, "unknown peer key must not call RemovePeer")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOnDisconnected_KnownPeer_CallsRemoveAndForgets(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
wgKey := validWGKey(t, 1)
|
||||||
|
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
|
||||||
|
require.Contains(t, m.rpPeerIDs, wgKey)
|
||||||
|
|
||||||
|
m.OnDisconnected(wgKey)
|
||||||
|
require.Len(t, srv.removed, 1)
|
||||||
|
require.NotContains(t, m.rpPeerIDs, wgKey, "peer must be forgotten after disconnect")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- IsPresharedKeyInitialized ---------------------------------------------
|
||||||
|
|
||||||
|
func TestIsPresharedKeyInitialized_UnknownPeer_ReturnsFalse(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
require.False(t, m.IsPresharedKeyInitialized(validWGKey(t, 1)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsPresharedKeyInitialized_AddedButNotHandshaken_ReturnsFalse(t *testing.T) {
|
||||||
|
srv := &mockServer{}
|
||||||
|
m := newTestManager(0xFF, srv)
|
||||||
|
|
||||||
|
wgKey := validWGKey(t, 2)
|
||||||
|
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
|
||||||
|
require.False(t, m.IsPresharedKeyInitialized(wgKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- NetbirdHandler.outputKey ----------------------------------------------
|
||||||
|
|
||||||
|
func TestHandler_OutputKey_FirstCallUsesUpdateOnlyFalse(t *testing.T) {
|
||||||
|
h := NewNetbirdHandler()
|
||||||
|
iface := &mockIface{}
|
||||||
|
h.SetInterface(iface)
|
||||||
|
|
||||||
|
pid := rp.PeerID{0x01}
|
||||||
|
wgKey := wgtypes.Key{0xAA}
|
||||||
|
h.AddPeer(pid, "wt0", rp.Key(wgKey))
|
||||||
|
|
||||||
|
psk := rp.Key{0xBB}
|
||||||
|
h.HandshakeCompleted(pid, psk)
|
||||||
|
|
||||||
|
require.Len(t, iface.calls, 1)
|
||||||
|
require.False(t, iface.calls[0].updateOnly, "first PSK rotation must use updateOnly=false")
|
||||||
|
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandler_OutputKey_SubsequentCallsUseUpdateOnlyTrue(t *testing.T) {
|
||||||
|
h := NewNetbirdHandler()
|
||||||
|
iface := &mockIface{}
|
||||||
|
h.SetInterface(iface)
|
||||||
|
|
||||||
|
pid := rp.PeerID{0x02}
|
||||||
|
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xCC}))
|
||||||
|
|
||||||
|
h.HandshakeCompleted(pid, rp.Key{0x01}) // first
|
||||||
|
h.HandshakeCompleted(pid, rp.Key{0x02}) // second
|
||||||
|
|
||||||
|
require.Len(t, iface.calls, 2)
|
||||||
|
require.False(t, iface.calls[0].updateOnly)
|
||||||
|
require.True(t, iface.calls[1].updateOnly, "subsequent rotations must use updateOnly=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandler_OutputKey_NilInterface_NoCrashNoCall(t *testing.T) {
|
||||||
|
h := NewNetbirdHandler()
|
||||||
|
// no SetInterface — iface remains nil
|
||||||
|
pid := rp.PeerID{0x03}
|
||||||
|
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{}))
|
||||||
|
|
||||||
|
// Must not panic.
|
||||||
|
h.HandshakeCompleted(pid, rp.Key{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandler_OutputKey_UnknownPeer_NoCall(t *testing.T) {
|
||||||
|
h := NewNetbirdHandler()
|
||||||
|
iface := &mockIface{}
|
||||||
|
h.SetInterface(iface)
|
||||||
|
|
||||||
|
h.HandshakeCompleted(rp.PeerID{0xFF}, rp.Key{})
|
||||||
|
require.Empty(t, iface.calls, "unknown peer id must not trigger SetPresharedKey")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandler_RemovePeer_ClearsInitializedState(t *testing.T) {
|
||||||
|
h := NewNetbirdHandler()
|
||||||
|
iface := &mockIface{}
|
||||||
|
h.SetInterface(iface)
|
||||||
|
|
||||||
|
pid := rp.PeerID{0x04}
|
||||||
|
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xDD}))
|
||||||
|
h.HandshakeCompleted(pid, rp.Key{0x01})
|
||||||
|
require.True(t, h.IsPeerInitialized(pid))
|
||||||
|
|
||||||
|
h.RemovePeer(pid)
|
||||||
|
require.False(t, h.IsPeerInitialized(pid), "RemovePeer must clear initialized flag")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandler_SetInterfaceAfterAddPeer_StillReceivesKey(t *testing.T) {
|
||||||
|
h := NewNetbirdHandler()
|
||||||
|
pid := rp.PeerID{0x05}
|
||||||
|
wgKey := wgtypes.Key{0xEE}
|
||||||
|
h.AddPeer(pid, "wt0", rp.Key(wgKey))
|
||||||
|
|
||||||
|
iface := &mockIface{}
|
||||||
|
h.SetInterface(iface) // set after AddPeer
|
||||||
|
|
||||||
|
h.HandshakeCompleted(pid, rp.Key{0x42})
|
||||||
|
require.Len(t, iface.calls, 1)
|
||||||
|
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
|
||||||
|
}
|
||||||
|
|||||||
42
client/internal/rosenpass/seed.go
Normal file
42
client/internal/rosenpass/seed.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package rosenpass
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DeterministicSeedKey derives a 32-byte WireGuard preshared key from a pair
|
||||||
|
// of peer public keys. Both peers, given the same key pair, produce the same
|
||||||
|
// output regardless of which side runs the function: the inputs are ordered
|
||||||
|
// lexicographically before concatenation.
|
||||||
|
//
|
||||||
|
// NetBird uses this value as the initial Rosenpass-side preshared key when no
|
||||||
|
// explicit account-level PSK is configured, so both peers converge on the same
|
||||||
|
// PSK before the first post-quantum handshake completes.
|
||||||
|
//
|
||||||
|
// The resulting key MUST NOT be treated as quantum-safe: it is deterministic
|
||||||
|
// from public keys and exists only to seed WireGuard until Rosenpass rotates
|
||||||
|
// in a real post-quantum PSK.
|
||||||
|
func DeterministicSeedKey(localKey, remoteKey string) (*wgtypes.Key, error) {
|
||||||
|
lk := []byte(localKey)
|
||||||
|
rk := []byte(remoteKey)
|
||||||
|
if len(lk) < 16 || len(rk) < 16 {
|
||||||
|
return nil, fmt.Errorf("rosenpass: peer keys must be at least 16 bytes (got local=%d, remote=%d)", len(lk), len(rk))
|
||||||
|
}
|
||||||
|
|
||||||
|
var keyInput []byte
|
||||||
|
if localKey > remoteKey {
|
||||||
|
keyInput = append(keyInput, lk[:16]...)
|
||||||
|
keyInput = append(keyInput, rk[:16]...)
|
||||||
|
} else {
|
||||||
|
keyInput = append(keyInput, rk[:16]...)
|
||||||
|
keyInput = append(keyInput, lk[:16]...)
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := wgtypes.NewKey(keyInput)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("rosenpass: deterministic seed key: %w", err)
|
||||||
|
}
|
||||||
|
return &key, nil
|
||||||
|
}
|
||||||
44
client/internal/rosenpass/seed_test.go
Normal file
44
client/internal/rosenpass/seed_test.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package rosenpass
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeterministicSeedKey_SameForBothSides(t *testing.T) {
|
||||||
|
// Peer A and peer B must derive the same PSK regardless of which side
|
||||||
|
// computes it: the function orders inputs internally.
|
||||||
|
a := strings.Repeat("a", 32)
|
||||||
|
b := strings.Repeat("b", 32)
|
||||||
|
|
||||||
|
keyAB, err := DeterministicSeedKey(a, b)
|
||||||
|
require.NoError(t, err)
|
||||||
|
keyBA, err := DeterministicSeedKey(b, a)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, keyAB.String(), keyBA.String(), "swapping arguments must yield identical key")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeterministicSeedKey_ChangesWithKeys(t *testing.T) {
|
||||||
|
a := strings.Repeat("a", 32)
|
||||||
|
b := strings.Repeat("b", 32)
|
||||||
|
c := strings.Repeat("c", 32)
|
||||||
|
|
||||||
|
keyAB, err := DeterministicSeedKey(a, b)
|
||||||
|
require.NoError(t, err)
|
||||||
|
keyAC, err := DeterministicSeedKey(a, c)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEqual(t, keyAB.String(), keyAC.String(), "different peer pair must yield different key")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeterministicSeedKey_TooShortKey_ReturnsError(t *testing.T) {
|
||||||
|
short := "short" // < 16 bytes
|
||||||
|
long := strings.Repeat("x", 32)
|
||||||
|
|
||||||
|
_, err := DeterministicSeedKey(short, long)
|
||||||
|
require.Error(t, err)
|
||||||
|
_, err = DeterministicSeedKey(long, short)
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
@@ -2,109 +2,54 @@ package ipfwdstate
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
)
|
)
|
||||||
|
|
||||||
// IPForwardingState tracks v4 and v6 IP-forwarding sysctl enables with
|
// IPForwardingState is a struct that keeps track of the IP forwarding state.
|
||||||
// independent refcounts so a v4-only routing setup doesn't flip v6 sysctls.
|
// todo: read initial state of the IP forwarding from the system and reset the state based on it.
|
||||||
|
// todo: separate v4/v6 forwarding state, since the sysctls are independent
|
||||||
|
// (net.ipv4.ip_forward vs net.ipv6.conf.all.forwarding). Currently the nftables
|
||||||
|
// manager shares one instance between both routers, which works only because
|
||||||
|
// EnableIPForwarding enables both sysctls in a single call.
|
||||||
type IPForwardingState struct {
|
type IPForwardingState struct {
|
||||||
mu sync.Mutex
|
enabledCounter int
|
||||||
|
|
||||||
v4Count int
|
|
||||||
v6Count int
|
|
||||||
|
|
||||||
wgIfaceName string
|
|
||||||
v6Saved map[string]int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewIPForwardingState(wgIfaceName string) *IPForwardingState {
|
func NewIPForwardingState() *IPForwardingState {
|
||||||
return &IPForwardingState{wgIfaceName: wgIfaceName}
|
return &IPForwardingState{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Counts returns the current v4 and v6 refcounts. Intended for diagnostics
|
func (f *IPForwardingState) RequestForwarding() error {
|
||||||
// and tests.
|
if f.enabledCounter != 0 {
|
||||||
func (f *IPForwardingState) Counts() (v4, v6 int) {
|
f.enabledCounter++
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
return f.v4Count, f.v6Count
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestForwarding enables the family's forwarding sysctl on first request.
|
|
||||||
func (f *IPForwardingState) RequestForwarding(v6 bool) error {
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
|
|
||||||
if v6 {
|
|
||||||
return f.requestV6()
|
|
||||||
}
|
|
||||||
return f.requestV4()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReleaseForwarding decrements the family counter. The last v6 release restores
|
|
||||||
// what enable captured. v4 stays on: net.ipv4.ip_forward is co-owned by other
|
|
||||||
// tooling (docker, k8s, libvirt).
|
|
||||||
func (f *IPForwardingState) ReleaseForwarding(v6 bool) error {
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
|
|
||||||
if v6 {
|
|
||||||
return f.releaseV6()
|
|
||||||
}
|
|
||||||
f.releaseV4()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *IPForwardingState) requestV4() error {
|
|
||||||
if f.v4Count == 0 {
|
|
||||||
if err := systemops.EnableV4IPForwarding(); err != nil {
|
|
||||||
return fmt.Errorf("enable IPv4 forwarding: %w", err)
|
|
||||||
}
|
|
||||||
log.Info("IPv4 forwarding enabled")
|
|
||||||
}
|
|
||||||
f.v4Count++
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *IPForwardingState) releaseV4() {
|
|
||||||
if f.v4Count > 0 {
|
|
||||||
f.v4Count--
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *IPForwardingState) requestV6() error {
|
|
||||||
if f.v6Count == 0 {
|
|
||||||
saved, err := systemops.EnableV6IPForwarding(f.wgIfaceName)
|
|
||||||
if err != nil {
|
|
||||||
if rerr := systemops.DisableV6IPForwarding(saved); rerr != nil {
|
|
||||||
log.Warnf("rollback partial v6 sysctls: %v", rerr)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("enable IPv6 forwarding: %w", err)
|
|
||||||
}
|
|
||||||
f.v6Saved = saved
|
|
||||||
log.Info("IPv6 forwarding enabled")
|
|
||||||
}
|
|
||||||
f.v6Count++
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *IPForwardingState) releaseV6() error {
|
|
||||||
if f.v6Count == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
f.v6Count--
|
|
||||||
if f.v6Count > 0 {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
saved := f.v6Saved
|
if err := systemops.EnableIPForwarding(); err != nil {
|
||||||
f.v6Saved = nil
|
return fmt.Errorf("failed to enable IP forwarding with sysctl: %w", err)
|
||||||
if err := systemops.DisableV6IPForwarding(saved); err != nil {
|
|
||||||
return fmt.Errorf("disable IPv6 forwarding: %w", err)
|
|
||||||
}
|
}
|
||||||
log.Info("IPv6 forwarding disabled")
|
f.enabledCounter = 1
|
||||||
|
log.Info("IP forwarding enabled")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *IPForwardingState) ReleaseForwarding() error {
|
||||||
|
if f.enabledCounter == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.enabledCounter > 1 {
|
||||||
|
f.enabledCounter--
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// if failed to disable IP forwarding we anyway decrement the counter
|
||||||
|
f.enabledCounter = 0
|
||||||
|
|
||||||
|
// todo call systemops.DisableIPForwarding()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,9 @@
|
|||||||
|
//go:build dragonfly || freebsd || netbsd || openbsd
|
||||||
|
|
||||||
|
package systemops
|
||||||
|
|
||||||
|
// IgnoreAddedDefaultRoute reports whether an RTM_ADD default route with the
|
||||||
|
// given flags should be ignored by the network monitor.
|
||||||
|
func IgnoreAddedDefaultRoute(flags int) bool {
|
||||||
|
return filterRoutesByFlags(flags)
|
||||||
|
}
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
//go:build darwin
|
||||||
|
|
||||||
|
package systemops
|
||||||
|
|
||||||
|
import "golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
// IgnoreAddedDefaultRoute reports whether an RTM_ADD default route with the
|
||||||
|
// given flags should be ignored by the network monitor. Scoped routes
|
||||||
|
// (RTF_IFSCOPE) are tied to a specific interface index and cannot replace the
|
||||||
|
// unscoped default the kernel uses for general egress, so flapping ones (e.g.
|
||||||
|
// Wi-Fi calling IMS tunnels on ipsec0, Docker bridges, scoped utun defaults)
|
||||||
|
// must not trigger an engine restart.
|
||||||
|
func IgnoreAddedDefaultRoute(flags int) bool {
|
||||||
|
if filterRoutesByFlags(flags) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if flags&unix.RTF_IFSCOPE != 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -32,17 +32,8 @@ func (r *SysOps) removeFromRouteTable(netip.Prefix, Nexthop) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func EnableV4IPForwarding() error {
|
func EnableIPForwarding() error {
|
||||||
log.Infof("Enable IPv4 forwarding is not implemented on %s", runtime.GOOS)
|
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func EnableV6IPForwarding(string) (map[string]int, error) {
|
|
||||||
log.Infof("Enable IPv6 forwarding is not implemented on %s", runtime.GOOS)
|
|
||||||
return map[string]int{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DisableV6IPForwarding(map[string]int) error {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -58,17 +58,8 @@ func (r *SysOps) removeFromRouteTable(netip.Prefix, Nexthop) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func EnableV4IPForwarding() error {
|
func EnableIPForwarding() error {
|
||||||
log.Infof("Enable IPv4 forwarding is not implemented on %s", runtime.GOOS)
|
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func EnableV6IPForwarding(string) (map[string]int, error) {
|
|
||||||
log.Infof("Enable IPv6 forwarding is not implemented on %s", runtime.GOOS)
|
|
||||||
return map[string]int{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DisableV6IPForwarding(map[string]int) error {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -763,10 +763,13 @@ func flushRoutes(tableID, family int) error {
|
|||||||
return nberrors.FormatErrorOrNil(result)
|
return nberrors.FormatErrorOrNil(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func EnableV4IPForwarding() error {
|
func EnableIPForwarding() error {
|
||||||
if _, err := sysctl.Set(ipv4ForwardingPath, 1, false); err != nil {
|
if _, err := sysctl.Set(ipv4ForwardingPath, 1, false); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if _, err := sysctl.Set(ipv6ForwardingPath, 1, false); err != nil {
|
||||||
|
log.Warnf("failed to enable IPv6 forwarding: %v", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -43,17 +43,8 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error
|
|||||||
return r.genericRemoveVPNRoute(prefix, intf)
|
return r.genericRemoveVPNRoute(prefix, intf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func EnableV4IPForwarding() error {
|
func EnableIPForwarding() error {
|
||||||
log.Infof("Enable IPv4 forwarding is not implemented on %s", runtime.GOOS)
|
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func EnableV6IPForwarding(string) (map[string]int, error) {
|
|
||||||
log.Infof("Enable IPv6 forwarding is not implemented on %s", runtime.GOOS)
|
|
||||||
return map[string]int{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DisableV6IPForwarding(map[string]int) error {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,82 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package systemops
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// 1 (default) accepts RAs only while forwarding is off; 2 keeps RA
|
|
||||||
// acceptance on regardless, so RA-installed host defaults survive our
|
|
||||||
// v6 forwarding flip.
|
|
||||||
acceptRAInterfacePath = "net.ipv6.conf.%s.accept_ra"
|
|
||||||
acceptRAProcPathFormat = "/proc/sys/net/ipv6/conf/%s/accept_ra"
|
|
||||||
)
|
|
||||||
|
|
||||||
// EnableV6IPForwarding bumps accept_ra=2 on host v6 interfaces before flipping
|
|
||||||
// forwarding=1, so RA-installed host defaults survive. Returns the prior values
|
|
||||||
// of sysctls we actually changed; entries already at the target are omitted.
|
|
||||||
func EnableV6IPForwarding(wgIfaceName string) (map[string]int, error) {
|
|
||||||
saved := map[string]int{}
|
|
||||||
bumpAcceptRA(saved, wgIfaceName)
|
|
||||||
|
|
||||||
oldVal, err := sysctl.Set(ipv6ForwardingPath, 1, false)
|
|
||||||
if err != nil {
|
|
||||||
return saved, err
|
|
||||||
}
|
|
||||||
if oldVal != 1 {
|
|
||||||
saved[ipv6ForwardingPath] = oldVal
|
|
||||||
}
|
|
||||||
return saved, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DisableV6IPForwarding restores what EnableV6IPForwarding captured.
|
|
||||||
func DisableV6IPForwarding(saved map[string]int) error {
|
|
||||||
var result *multierror.Error
|
|
||||||
for key, value := range saved {
|
|
||||||
if _, err := sysctl.Set(key, value, false); err != nil {
|
|
||||||
result = multierror.Append(result, fmt.Errorf("restore %s: %w", key, err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
func bumpAcceptRA(saved map[string]int, wgIfaceName string) {
|
|
||||||
interfaces, err := net.Interfaces()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("list interfaces for accept_ra: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, intf := range interfaces {
|
|
||||||
if intf.Name == "lo" || intf.Name == wgIfaceName {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
bumpAcceptRAForInterface(saved, intf.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func bumpAcceptRAForInterface(saved map[string]int, name string) {
|
|
||||||
key := fmt.Sprintf(acceptRAInterfacePath, name)
|
|
||||||
// Build procfs path from name, not the dotted key: VLAN names like eth0.100.
|
|
||||||
if _, err := os.Stat(fmt.Sprintf(acceptRAProcPathFormat, name)); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// onlyIfOne=true: leave admin overrides (0, 2) alone.
|
|
||||||
oldVal, err := sysctl.Set(key, 2, true)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("bump %s: %v", key, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if oldVal != 2 {
|
|
||||||
saved[key] = oldVal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -96,17 +96,19 @@ func (m *Manager) Stop(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
cancel := m.cancel
|
||||||
|
done := m.done
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
if m.cancel == nil {
|
if cancel == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
m.cancel()
|
cancel()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case <-m.done:
|
case <-done:
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
@@ -315,7 +315,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
|||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -3,15 +3,14 @@
|
|||||||
package system
|
package system
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/zcalusic/sysinfo"
|
"github.com/zcalusic/sysinfo"
|
||||||
|
|
||||||
@@ -29,19 +28,11 @@ func UpdateStaticInfoAsync() {
|
|||||||
|
|
||||||
// GetInfo retrieves and parses the system information
|
// GetInfo retrieves and parses the system information
|
||||||
func GetInfo(ctx context.Context) *Info {
|
func GetInfo(ctx context.Context) *Info {
|
||||||
info := _getInfo()
|
kernelName, kernelVersion, kernelPlatform := kernelInfo()
|
||||||
for strings.Contains(info, "broken pipe") {
|
|
||||||
info = _getInfo()
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
}
|
|
||||||
|
|
||||||
osStr := strings.ReplaceAll(info, "\n", "")
|
|
||||||
osStr = strings.ReplaceAll(osStr, "\r\n", "")
|
|
||||||
osInfo := strings.Split(osStr, " ")
|
|
||||||
|
|
||||||
osName, osVersion := readOsReleaseFile()
|
osName, osVersion := readOsReleaseFile()
|
||||||
if osName == "" {
|
if osName == "" {
|
||||||
osName = osInfo[3]
|
osName = kernelName
|
||||||
}
|
}
|
||||||
|
|
||||||
systemHostname, _ := os.Hostname()
|
systemHostname, _ := os.Hostname()
|
||||||
@@ -58,8 +49,8 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
}
|
}
|
||||||
|
|
||||||
gio := &Info{
|
gio := &Info{
|
||||||
Kernel: osInfo[0],
|
Kernel: kernelName,
|
||||||
Platform: osInfo[2],
|
Platform: kernelPlatform,
|
||||||
OS: osName,
|
OS: osName,
|
||||||
OSVersion: osVersion,
|
OSVersion: osVersion,
|
||||||
Hostname: extractDeviceName(ctx, systemHostname),
|
Hostname: extractDeviceName(ctx, systemHostname),
|
||||||
@@ -67,7 +58,7 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
CPUs: runtime.NumCPU(),
|
CPUs: runtime.NumCPU(),
|
||||||
NetbirdVersion: version.NetbirdVersion(),
|
NetbirdVersion: version.NetbirdVersion(),
|
||||||
UIVersion: extractUserAgent(ctx),
|
UIVersion: extractUserAgent(ctx),
|
||||||
KernelVersion: osInfo[1],
|
KernelVersion: kernelVersion,
|
||||||
NetworkAddresses: addrs,
|
NetworkAddresses: addrs,
|
||||||
SystemSerialNumber: si.SystemSerialNumber,
|
SystemSerialNumber: si.SystemSerialNumber,
|
||||||
SystemProductName: si.SystemProductName,
|
SystemProductName: si.SystemProductName,
|
||||||
@@ -78,18 +69,12 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
return gio
|
return gio
|
||||||
}
|
}
|
||||||
|
|
||||||
func _getInfo() string {
|
func kernelInfo() (string, string, string) {
|
||||||
cmd := exec.Command("uname", "-srio")
|
var uts unix.Utsname
|
||||||
cmd.Stdin = strings.NewReader("some")
|
if err := unix.Uname(&uts); err != nil {
|
||||||
var out bytes.Buffer
|
return "", "", ""
|
||||||
var stderr bytes.Buffer
|
|
||||||
cmd.Stdout = &out
|
|
||||||
cmd.Stderr = &stderr
|
|
||||||
err := cmd.Run()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("getInfo: %s", err)
|
|
||||||
}
|
}
|
||||||
return out.String()
|
return unix.ByteSliceToString(uts.Sysname[:]), unix.ByteSliceToString(uts.Release[:]), unix.ByteSliceToString(uts.Machine[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
func sysInfo() (string, string, string) {
|
func sysInfo() (string, string, string) {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync"
|
||||||
"syscall/js"
|
"syscall/js"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -13,7 +14,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
certValidationTimeout = 60 * time.Second
|
certValidationTimeout = 5 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) {
|
func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) {
|
||||||
@@ -46,17 +47,31 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert
|
|||||||
|
|
||||||
promise := conn.wsHandlers.Call("onCertificateRequest", certInfo)
|
promise := conn.wsHandlers.Call("onCertificateRequest", certInfo)
|
||||||
|
|
||||||
resultChan := make(chan bool)
|
resultChan := make(chan bool, 1)
|
||||||
errorChan := make(chan error)
|
errorChan := make(chan error, 1)
|
||||||
|
|
||||||
promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
// Release from inside the callbacks so a post-timeout promise resolution
|
||||||
result := args[0].Bool()
|
// does not invoke an already-released func.
|
||||||
resultChan <- result
|
var thenFn, catchFn js.Func
|
||||||
|
var releaseOnce sync.Once
|
||||||
|
release := func() {
|
||||||
|
releaseOnce.Do(func() {
|
||||||
|
thenFn.Release()
|
||||||
|
catchFn.Release()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
thenFn = js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||||
|
defer release()
|
||||||
|
resultChan <- args[0].Bool()
|
||||||
return nil
|
return nil
|
||||||
})).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
})
|
||||||
|
catchFn = js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||||
|
defer release()
|
||||||
errorChan <- fmt.Errorf("certificate validation failed")
|
errorChan <- fmt.Errorf("certificate validation failed")
|
||||||
return nil
|
return nil
|
||||||
}))
|
})
|
||||||
|
|
||||||
|
promise.Call("then", thenFn).Call("catch", catchFn)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case result := <-resultChan:
|
case result := <-resultChan:
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"syscall/js"
|
"syscall/js"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -57,6 +58,8 @@ type RDCleanPathProxy struct {
|
|||||||
}
|
}
|
||||||
activeConnections map[string]*proxyConnection
|
activeConnections map[string]*proxyConnection
|
||||||
destinations map[string]string
|
destinations map[string]string
|
||||||
|
pendingHandlers map[string]js.Func
|
||||||
|
nextID atomic.Uint64
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,8 +69,15 @@ type proxyConnection struct {
|
|||||||
rdpConn net.Conn
|
rdpConn net.Conn
|
||||||
tlsConn *tls.Conn
|
tlsConn *tls.Conn
|
||||||
wsHandlers js.Value
|
wsHandlers js.Value
|
||||||
ctx context.Context
|
// Go-side callbacks exposed to JS. js.FuncOf pins the Go closure in a
|
||||||
cancel context.CancelFunc
|
// global handle map and MUST be released, otherwise every connection
|
||||||
|
// leaks the Go memory the closure captures.
|
||||||
|
wsHandlerFn js.Func
|
||||||
|
onMessageFn js.Func
|
||||||
|
onCloseFn js.Func
|
||||||
|
cleanupOnce sync.Once
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRDCleanPathProxy creates a new RDCleanPath proxy
|
// NewRDCleanPathProxy creates a new RDCleanPath proxy
|
||||||
@@ -80,7 +90,11 @@ func NewRDCleanPathProxy(client interface {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateProxy creates a new proxy endpoint for the given destination
|
// CreateProxy creates a new proxy endpoint for the given destination.
|
||||||
|
// The registered handler fn and its destinations/pendingHandlers entries are
|
||||||
|
// only released once a connection is established and cleanupConnection runs.
|
||||||
|
// If a caller invokes CreateProxy but never connects to the returned URL,
|
||||||
|
// those entries stay pinned for the lifetime of the page.
|
||||||
func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
||||||
destination := net.JoinHostPort(hostname, port)
|
destination := net.JoinHostPort(hostname, port)
|
||||||
|
|
||||||
@@ -88,7 +102,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
|||||||
resolve := args[0]
|
resolve := args[0]
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections))
|
proxyID := fmt.Sprintf("proxy_%d", p.nextID.Add(1))
|
||||||
|
|
||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
if p.destinations == nil {
|
if p.destinations == nil {
|
||||||
@@ -100,7 +114,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
|||||||
proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID)
|
proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID)
|
||||||
|
|
||||||
// Register the WebSocket handler for this specific proxy
|
// Register the WebSocket handler for this specific proxy
|
||||||
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any {
|
handlerFn := js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||||
if len(args) < 1 {
|
if len(args) < 1 {
|
||||||
return js.ValueOf("error: requires WebSocket argument")
|
return js.ValueOf("error: requires WebSocket argument")
|
||||||
}
|
}
|
||||||
@@ -108,7 +122,14 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
|||||||
ws := args[0]
|
ws := args[0]
|
||||||
p.HandleWebSocketConnection(ws, proxyID)
|
p.HandleWebSocketConnection(ws, proxyID)
|
||||||
return nil
|
return nil
|
||||||
}))
|
})
|
||||||
|
p.mu.Lock()
|
||||||
|
if p.pendingHandlers == nil {
|
||||||
|
p.pendingHandlers = make(map[string]js.Func)
|
||||||
|
}
|
||||||
|
p.pendingHandlers[proxyID] = handlerFn
|
||||||
|
p.mu.Unlock()
|
||||||
|
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), handlerFn)
|
||||||
|
|
||||||
log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination)
|
log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination)
|
||||||
resolve.Invoke(proxyURL)
|
resolve.Invoke(proxyURL)
|
||||||
@@ -142,6 +163,10 @@ func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string
|
|||||||
|
|
||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
p.activeConnections[proxyID] = conn
|
p.activeConnections[proxyID] = conn
|
||||||
|
if fn, ok := p.pendingHandlers[proxyID]; ok {
|
||||||
|
conn.wsHandlerFn = fn
|
||||||
|
delete(p.pendingHandlers, proxyID)
|
||||||
|
}
|
||||||
p.mu.Unlock()
|
p.mu.Unlock()
|
||||||
|
|
||||||
p.setupWebSocketHandlers(ws, conn)
|
p.setupWebSocketHandlers(ws, conn)
|
||||||
@@ -150,7 +175,7 @@ func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) {
|
func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) {
|
||||||
ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any {
|
conn.onMessageFn = js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
if len(args) < 1 {
|
if len(args) < 1 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -158,13 +183,15 @@ func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnec
|
|||||||
data := args[0]
|
data := args[0]
|
||||||
go p.handleWebSocketMessage(conn, data)
|
go p.handleWebSocketMessage(conn, data)
|
||||||
return nil
|
return nil
|
||||||
}))
|
})
|
||||||
|
ws.Set("onGoMessage", conn.onMessageFn)
|
||||||
|
|
||||||
ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any {
|
conn.onCloseFn = js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||||
log.Debug("WebSocket closed by JavaScript")
|
log.Debug("WebSocket closed by JavaScript")
|
||||||
conn.cancel()
|
conn.cancel()
|
||||||
return nil
|
return nil
|
||||||
}))
|
})
|
||||||
|
ws.Set("onGoClose", conn.onCloseFn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) {
|
func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) {
|
||||||
@@ -261,25 +288,49 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) {
|
func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) {
|
||||||
log.Debugf("Cleaning up connection %s", conn.id)
|
conn.cleanupOnce.Do(func() {
|
||||||
conn.cancel()
|
log.Debugf("Cleaning up connection %s", conn.id)
|
||||||
if conn.tlsConn != nil {
|
conn.cancel()
|
||||||
log.Debug("Closing TLS connection")
|
if conn.tlsConn != nil {
|
||||||
if err := conn.tlsConn.Close(); err != nil {
|
log.Debug("Closing TLS connection")
|
||||||
log.Debugf("Error closing TLS connection: %v", err)
|
if err := conn.tlsConn.Close(); err != nil {
|
||||||
|
log.Debugf("Error closing TLS connection: %v", err)
|
||||||
|
}
|
||||||
|
conn.tlsConn = nil
|
||||||
}
|
}
|
||||||
conn.tlsConn = nil
|
if conn.rdpConn != nil {
|
||||||
}
|
log.Debug("Closing TCP connection")
|
||||||
if conn.rdpConn != nil {
|
if err := conn.rdpConn.Close(); err != nil {
|
||||||
log.Debug("Closing TCP connection")
|
log.Debugf("Error closing TCP connection: %v", err)
|
||||||
if err := conn.rdpConn.Close(); err != nil {
|
}
|
||||||
log.Debugf("Error closing TCP connection: %v", err)
|
conn.rdpConn = nil
|
||||||
}
|
}
|
||||||
conn.rdpConn = nil
|
js.Global().Delete(fmt.Sprintf("handleRDCleanPathWebSocket_%s", conn.id))
|
||||||
}
|
|
||||||
p.mu.Lock()
|
// Detach before releasing so late JS calls surface as TypeError instead
|
||||||
delete(p.activeConnections, conn.id)
|
// of silent "call to released function".
|
||||||
p.mu.Unlock()
|
if conn.wsHandlers.Truthy() {
|
||||||
|
conn.wsHandlers.Set("onGoMessage", js.Undefined())
|
||||||
|
conn.wsHandlers.Set("onGoClose", js.Undefined())
|
||||||
|
}
|
||||||
|
|
||||||
|
// wsHandlerFn may be zero-value if the pending handler lookup missed.
|
||||||
|
if conn.wsHandlerFn.Truthy() {
|
||||||
|
conn.wsHandlerFn.Release()
|
||||||
|
}
|
||||||
|
if conn.onMessageFn.Truthy() {
|
||||||
|
conn.onMessageFn.Release()
|
||||||
|
}
|
||||||
|
if conn.onCloseFn.Truthy() {
|
||||||
|
conn.onCloseFn.Release()
|
||||||
|
}
|
||||||
|
|
||||||
|
p.mu.Lock()
|
||||||
|
delete(p.activeConnections, conn.id)
|
||||||
|
delete(p.destinations, conn.id)
|
||||||
|
delete(p.pendingHandlers, conn.id)
|
||||||
|
p.mu.Unlock()
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
|
func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
func CreateJSInterface(client *Client) js.Value {
|
func CreateJSInterface(client *Client) js.Value {
|
||||||
jsInterface := js.Global().Get("Object").Call("create", js.Null())
|
jsInterface := js.Global().Get("Object").Call("create", js.Null())
|
||||||
|
|
||||||
jsInterface.Set("write", js.FuncOf(func(this js.Value, args []js.Value) any {
|
writeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
if len(args) < 1 {
|
if len(args) < 1 {
|
||||||
return js.ValueOf(false)
|
return js.ValueOf(false)
|
||||||
}
|
}
|
||||||
@@ -32,9 +32,10 @@ func CreateJSInterface(client *Client) js.Value {
|
|||||||
|
|
||||||
_, err := client.Write(bytes)
|
_, err := client.Write(bytes)
|
||||||
return js.ValueOf(err == nil)
|
return js.ValueOf(err == nil)
|
||||||
}))
|
})
|
||||||
|
jsInterface.Set("write", writeFunc)
|
||||||
|
|
||||||
jsInterface.Set("resize", js.FuncOf(func(this js.Value, args []js.Value) any {
|
resizeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
if len(args) < 2 {
|
if len(args) < 2 {
|
||||||
return js.ValueOf(false)
|
return js.ValueOf(false)
|
||||||
}
|
}
|
||||||
@@ -42,14 +43,26 @@ func CreateJSInterface(client *Client) js.Value {
|
|||||||
rows := args[1].Int()
|
rows := args[1].Int()
|
||||||
err := client.Resize(cols, rows)
|
err := client.Resize(cols, rows)
|
||||||
return js.ValueOf(err == nil)
|
return js.ValueOf(err == nil)
|
||||||
}))
|
})
|
||||||
|
jsInterface.Set("resize", resizeFunc)
|
||||||
|
|
||||||
jsInterface.Set("close", js.FuncOf(func(this js.Value, args []js.Value) any {
|
closeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
client.Close()
|
client.Close()
|
||||||
return js.Undefined()
|
return js.Undefined()
|
||||||
}))
|
})
|
||||||
|
jsInterface.Set("close", closeFunc)
|
||||||
|
|
||||||
go readLoop(client, jsInterface)
|
go func() {
|
||||||
|
readLoop(client, jsInterface)
|
||||||
|
// Detach before releasing so late JS calls surface as TypeError instead
|
||||||
|
// of silent "call to released function".
|
||||||
|
jsInterface.Set("write", js.Undefined())
|
||||||
|
jsInterface.Set("resize", js.Undefined())
|
||||||
|
jsInterface.Set("close", js.Undefined())
|
||||||
|
writeFunc.Release()
|
||||||
|
resizeFunc.Release()
|
||||||
|
closeFunc.Release()
|
||||||
|
}()
|
||||||
|
|
||||||
return jsInterface
|
return jsInterface
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -332,7 +332,7 @@ func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) {
|
|||||||
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
|
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
|
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), s.IDPHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
|
||||||
if servers.relaySrv != nil {
|
if servers.relaySrv != nil {
|
||||||
log.Infof("Relay WebSocket handler added (path: /relay)")
|
log.Infof("Relay WebSocket handler added (path: /relay)")
|
||||||
}
|
}
|
||||||
@@ -521,7 +521,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
|
|||||||
}
|
}
|
||||||
|
|
||||||
// createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic
|
// createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic
|
||||||
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
|
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
|
||||||
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
|
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
|
||||||
|
|
||||||
var relayAcceptFn func(conn listener.Conn)
|
var relayAcceptFn func(conn listener.Conn)
|
||||||
@@ -556,6 +556,10 @@ func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, re
|
|||||||
http.Error(w, "Relay service not enabled", http.StatusNotFound)
|
http.Error(w, "Relay service not enabled", http.StatusNotFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Embedded IdP (Dex)
|
||||||
|
case idpHandler != nil && strings.HasPrefix(r.URL.Path, "/oauth2"):
|
||||||
|
idpHandler.ServeHTTP(w, r)
|
||||||
|
|
||||||
// Management HTTP API (default)
|
// Management HTTP API (default)
|
||||||
default:
|
default:
|
||||||
httpHandler.ServeHTTP(w, r)
|
httpHandler.ServeHTTP(w, r)
|
||||||
|
|||||||
@@ -53,6 +53,9 @@ type NameServerGroup struct {
|
|||||||
ID string `gorm:"primaryKey"`
|
ID string `gorm:"primaryKey"`
|
||||||
// AccountID is a reference to Account that this object belongs
|
// AccountID is a reference to Account that this object belongs
|
||||||
AccountID string `gorm:"index"`
|
AccountID string `gorm:"index"`
|
||||||
|
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||||
|
// compact wire id when sending NetworkMap components to capable peers.
|
||||||
|
AccountSeqID uint32 `json:"-" gorm:"index:idx_nameserver_groups_account_seq_id;not null;default:0"`
|
||||||
// Name group name
|
// Name group name
|
||||||
Name string
|
Name string
|
||||||
// Description group description
|
// Description group description
|
||||||
|
|||||||
12
go.mod
12
go.mod
@@ -3,7 +3,7 @@ module github.com/netbirdio/netbird
|
|||||||
go 1.25.5
|
go 1.25.5
|
||||||
|
|
||||||
require (
|
require (
|
||||||
cunicu.li/go-rosenpass v0.4.0
|
cunicu.li/go-rosenpass v0.5.42
|
||||||
github.com/cenkalti/backoff/v4 v4.3.0
|
github.com/cenkalti/backoff/v4 v4.3.0
|
||||||
github.com/cloudflare/circl v1.3.3 // indirect
|
github.com/cloudflare/circl v1.3.3 // indirect
|
||||||
github.com/golang/protobuf v1.5.4
|
github.com/golang/protobuf v1.5.4
|
||||||
@@ -19,8 +19,8 @@ require (
|
|||||||
github.com/vishvananda/netlink v1.3.1
|
github.com/vishvananda/netlink v1.3.1
|
||||||
golang.org/x/crypto v0.50.0
|
golang.org/x/crypto v0.50.0
|
||||||
golang.org/x/sys v0.43.0
|
golang.org/x/sys v0.43.0
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||||
google.golang.org/grpc v1.80.0
|
google.golang.org/grpc v1.80.0
|
||||||
google.golang.org/protobuf v1.36.11
|
google.golang.org/protobuf v1.36.11
|
||||||
@@ -38,7 +38,7 @@ require (
|
|||||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3
|
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3
|
||||||
github.com/c-robinson/iplib v1.0.3
|
github.com/c-robinson/iplib v1.0.3
|
||||||
github.com/caddyserver/certmagic v0.21.3
|
github.com/caddyserver/certmagic v0.21.3
|
||||||
github.com/cilium/ebpf v0.15.0
|
github.com/cilium/ebpf v0.19.0
|
||||||
github.com/coder/websocket v1.8.14
|
github.com/coder/websocket v1.8.14
|
||||||
github.com/coreos/go-iptables v0.7.0
|
github.com/coreos/go-iptables v0.7.0
|
||||||
github.com/coreos/go-oidc/v3 v3.18.0
|
github.com/coreos/go-oidc/v3 v3.18.0
|
||||||
@@ -60,7 +60,7 @@ require (
|
|||||||
github.com/google/go-cmp v0.7.0
|
github.com/google/go-cmp v0.7.0
|
||||||
github.com/google/gopacket v1.1.19
|
github.com/google/gopacket v1.1.19
|
||||||
github.com/google/nftables v0.3.0
|
github.com/google/nftables v0.3.0
|
||||||
github.com/gopacket/gopacket v1.1.1
|
github.com/gopacket/gopacket v1.4.0
|
||||||
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
|
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
|
||||||
github.com/hashicorp/go-multierror v1.1.1
|
github.com/hashicorp/go-multierror v1.1.1
|
||||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
||||||
@@ -335,7 +335,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
|||||||
|
|
||||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||||
|
|
||||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0
|
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f
|
||||||
|
|
||||||
replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||||
|
|
||||||
|
|||||||
26
go.sum
26
go.sum
@@ -7,8 +7,8 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdB
|
|||||||
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
|
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
|
||||||
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:b8xUw3004wk+3ipBhu0VU4RtUJsegMIiqjxSK4++lzA=
|
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:b8xUw3004wk+3ipBhu0VU4RtUJsegMIiqjxSK4++lzA=
|
||||||
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
|
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
|
||||||
cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw=
|
cunicu.li/go-rosenpass v0.5.42 h1:fRDsGwCxd7DhDgZI1Pxeo8GtNyq8BESZJ7w2/BGGJtU=
|
||||||
cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M=
|
cunicu.li/go-rosenpass v0.5.42/go.mod h1:YRBeyKOe/gWpSX2kpDUec5p9t0XOLsshTguId5gTGVg=
|
||||||
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
|
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
|
||||||
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
|
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
|
||||||
filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw=
|
filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw=
|
||||||
@@ -111,8 +111,8 @@ github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1x
|
|||||||
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
|
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
|
github.com/cilium/ebpf v0.19.0 h1:Ro/rE64RmFBeA9FGjcTc+KmCeY6jXmryu6FfnzPRIao=
|
||||||
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
|
github.com/cilium/ebpf v0.19.0/go.mod h1:fLCgMo3l8tZmAdM3B2XqdFzXBpwkcSTroaVqN08OWVY=
|
||||||
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
||||||
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||||
@@ -225,8 +225,8 @@ github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3Bum
|
|||||||
github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ=
|
github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ=
|
||||||
github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM=
|
github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM=
|
||||||
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
|
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
|
||||||
github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI=
|
github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6 h1:teYtXy9B7y5lHTp8V9KPxpYRAVA7dozigQcMiBust1s=
|
||||||
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
|
github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6/go.mod h1:p4lGIVX+8Wa6ZPNDvqcxq36XpUDLh42FLetFU7odllI=
|
||||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||||
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
|
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
|
||||||
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
|
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
|
||||||
@@ -307,8 +307,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA
|
|||||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
||||||
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
|
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
|
||||||
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
|
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
|
||||||
github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw=
|
github.com/gopacket/gopacket v1.4.0 h1:cr1OlFpzksCkZHNO0eLjaSSOrMQnpPXg0j6qHIY3y2U=
|
||||||
github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs=
|
github.com/gopacket/gopacket v1.4.0/go.mod h1:EpvsxINeehp5qj4YMKMLf2/dekdhKn2IIAO/ZOifS7o=
|
||||||
github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
|
github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
|
||||||
github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w=
|
github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w=
|
||||||
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
||||||
@@ -390,6 +390,8 @@ github.com/jonboulle/clockwork v0.5.0 h1:Hyh9A8u51kptdkR+cqRpT1EebBwTn1oK9YfGYbd
|
|||||||
github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60=
|
github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60=
|
||||||
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
||||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||||
|
github.com/jsimonetti/rtnetlink/v2 v2.0.1 h1:xda7qaHDSVOsADNouv7ukSuicKZO7GgVUCXxpaIEIlM=
|
||||||
|
github.com/jsimonetti/rtnetlink/v2 v2.0.1/go.mod h1:7MoNYNbb3UaDHtF8udiJo/RH6VsTKP1pqKLUTVCvToE=
|
||||||
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||||
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M=
|
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M=
|
||||||
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
|
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
|
||||||
@@ -499,8 +501,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
|||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ=
|
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f h1:ff2D57RBjWtyQ2wVwJOxOgXAXOe/J2lJWtSX0Bz/BRk=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||||
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
|
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
|
||||||
@@ -900,8 +902,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
|
|||||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
||||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||||
|
|||||||
@@ -308,7 +308,7 @@ func (s *Storage) OpenStorage(logger *slog.Logger) (storage.Storage, error) {
|
|||||||
if file == "" {
|
if file == "" {
|
||||||
return nil, fmt.Errorf("sqlite3 storage requires 'file' config")
|
return nil, fmt.Errorf("sqlite3 storage requires 'file' config")
|
||||||
}
|
}
|
||||||
return (&sql.SQLite3{File: file}).Open(logger)
|
return newSQLite3(file).Open(logger)
|
||||||
case "postgres":
|
case "postgres":
|
||||||
dsn, _ := s.Config["dsn"].(string)
|
dsn, _ := s.Config["dsn"].(string)
|
||||||
if dsn == "" {
|
if dsn == "" {
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import (
|
|||||||
"github.com/dexidp/dex/server"
|
"github.com/dexidp/dex/server"
|
||||||
"github.com/dexidp/dex/server/signer"
|
"github.com/dexidp/dex/server/signer"
|
||||||
"github.com/dexidp/dex/storage"
|
"github.com/dexidp/dex/storage"
|
||||||
"github.com/dexidp/dex/storage/sql"
|
|
||||||
jose "github.com/go-jose/go-jose/v4"
|
jose "github.com/go-jose/go-jose/v4"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
@@ -77,7 +76,7 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
|
|||||||
|
|
||||||
// Initialize SQLite storage
|
// Initialize SQLite storage
|
||||||
dbPath := filepath.Join(config.DataDir, "oidc.db")
|
dbPath := filepath.Join(config.DataDir, "oidc.db")
|
||||||
sqliteConfig := &sql.SQLite3{File: dbPath}
|
sqliteConfig := newSQLite3(dbPath)
|
||||||
stor, err := sqliteConfig.Open(logger)
|
stor, err := sqliteConfig.Open(logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to open storage: %w", err)
|
return nil, fmt.Errorf("failed to open storage: %w", err)
|
||||||
|
|||||||
15
idp/dex/sqlite_cgo.go
Normal file
15
idp/dex/sqlite_cgo.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
//go:build cgo
|
||||||
|
|
||||||
|
package dex
|
||||||
|
|
||||||
|
import (
|
||||||
|
sql "github.com/dexidp/dex/storage/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newSQLite3 builds the dex SQLite3 config. CGO builds use the upstream
|
||||||
|
// struct that takes a File path. Non-CGO builds get an empty stub whose
|
||||||
|
// Open() returns the dex "SQLite not available" error — correct behaviour
|
||||||
|
// for binaries that can't link sqlite3 (e.g. cross-compiled ARM targets).
|
||||||
|
func newSQLite3(file string) *sql.SQLite3 {
|
||||||
|
return &sql.SQLite3{File: file}
|
||||||
|
}
|
||||||
15
idp/dex/sqlite_nocgo.go
Normal file
15
idp/dex/sqlite_nocgo.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
//go:build !cgo
|
||||||
|
|
||||||
|
package dex
|
||||||
|
|
||||||
|
import (
|
||||||
|
sql "github.com/dexidp/dex/storage/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newSQLite3 for non-CGO builds. The dex SQLite3 stub has no fields and its
|
||||||
|
// Open() returns an error documenting the missing CGO support — correct
|
||||||
|
// behaviour for cross-compiled artefacts that never actually run the
|
||||||
|
// embedded IdP. The `file` argument is ignored.
|
||||||
|
func newSQLite3(_ string) *sql.SQLite3 {
|
||||||
|
return &sql.SQLite3{}
|
||||||
|
}
|
||||||
@@ -55,6 +55,12 @@ type Controller struct {
|
|||||||
proxyController port_forwarding.Controller
|
proxyController port_forwarding.Controller
|
||||||
|
|
||||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||||
|
|
||||||
|
// componentsDisabled, when true, forces the controller to emit legacy
|
||||||
|
// proto.NetworkMap to every peer regardless of capability. Set once at
|
||||||
|
// construction and never written after — readers race-free without a
|
||||||
|
// mutex.
|
||||||
|
componentsDisabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type bufferUpdate struct {
|
type bufferUpdate struct {
|
||||||
@@ -81,12 +87,27 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
|||||||
settingsManager: settingsManager,
|
settingsManager: settingsManager,
|
||||||
dnsDomain: dnsDomain,
|
dnsDomain: dnsDomain,
|
||||||
config: config,
|
config: config,
|
||||||
|
componentsDisabled: parseBoolEnv("NB_NETWORK_MAP_COMPONENTS_DISABLE"),
|
||||||
|
|
||||||
proxyController: proxyController,
|
proxyController: proxyController,
|
||||||
EphemeralPeersManager: ephemeralPeersManager,
|
EphemeralPeersManager: ephemeralPeersManager,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PeerNeedsComponents reports whether the gRPC layer should emit the
|
||||||
|
// component-based wire format for this peer.
|
||||||
|
func (c *Controller) PeerNeedsComponents(p *nbpeer.Peer) bool {
|
||||||
|
return p != nil && p.SupportsComponentNetworkMap() && !c.componentsDisabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseBoolEnv reads an env var via strconv.ParseBool so callers accept the
|
||||||
|
// usual "1/t/T/TRUE/true/True" set instead of being strict about a single
|
||||||
|
// literal.
|
||||||
|
func parseBoolEnv(key string) bool {
|
||||||
|
v, _ := strconv.ParseBool(os.Getenv(key))
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) {
|
func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) {
|
||||||
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
|
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -112,7 +133,7 @@ func (c *Controller) CountStreams() int {
|
|||||||
return c.peersUpdateManager.CountStreams()
|
return c.peersUpdateManager.CountStreams()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
||||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -175,6 +196,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.accountManagerMetrics != nil {
|
||||||
|
c.accountManagerMetrics.CountNmapTriggered(string(reason.Resource), string(reason.Operation))
|
||||||
|
}
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
semaphore <- struct{}{}
|
semaphore <- struct{}{}
|
||||||
go func(p *nbpeer.Peer) {
|
go func(p *nbpeer.Peer) {
|
||||||
@@ -192,18 +217,26 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
||||||
start = time.Now()
|
start = time.Now()
|
||||||
|
|
||||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
result := account.GetPeerNetworkMapResult(ctx, p.ID, c.componentsDisabled, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
|
||||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap := proxyNetworkMaps[p.ID]
|
||||||
if ok {
|
if result.NetworkMap != nil && proxyNetworkMap != nil {
|
||||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
result.NetworkMap.Merge(proxyNetworkMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
peerGroups := account.GetPeerGroups(p.ID)
|
peerGroups := account.GetPeerGroups(p.ID)
|
||||||
start = time.Now()
|
start = time.Now()
|
||||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
var update *proto.SyncResponse
|
||||||
|
if result.IsComponents() {
|
||||||
|
// proxyNetworkMap rides the envelope as a ProxyPatch sidecar;
|
||||||
|
// the client merges it into Calculate()'s output the same
|
||||||
|
// way the legacy server did via NetworkMap.Merge.
|
||||||
|
update = grpc.ToComponentSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, result.Components, proxyNetworkMap, dnsDomain, postureChecks, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||||
|
} else {
|
||||||
|
update = grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, result.NetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||||
|
}
|
||||||
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
||||||
|
|
||||||
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
|
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
|
||||||
@@ -242,14 +275,14 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer b.mu.Unlock()
|
defer b.mu.Unlock()
|
||||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||||
if !b.update.Load() {
|
if !b.update.Load() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
b.update.Store(false)
|
b.update.Store(false)
|
||||||
if b.next == nil {
|
if b.next == nil {
|
||||||
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -265,7 +298,7 @@ func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, r
|
|||||||
if c.accountManagerMetrics != nil {
|
if c.accountManagerMetrics != nil {
|
||||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||||
}
|
}
|
||||||
return c.sendUpdateAccountPeers(ctx, accountID)
|
return c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
|
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
|
||||||
@@ -314,11 +347,11 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
result := account.GetPeerNetworkMapResult(ctx, peerId, c.componentsDisabled, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap := proxyNetworkMaps[peer.ID]
|
||||||
if ok {
|
if result.NetworkMap != nil && proxyNetworkMap != nil {
|
||||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
result.NetworkMap.Merge(proxyNetworkMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID)
|
extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID)
|
||||||
@@ -329,7 +362,12 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
peerGroups := account.GetPeerGroups(peerId)
|
peerGroups := account.GetPeerGroups(peerId)
|
||||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||||
|
|
||||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
var update *proto.SyncResponse
|
||||||
|
if result.IsComponents() {
|
||||||
|
update = grpc.ToComponentSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, result.Components, proxyNetworkMap, dnsDomain, postureChecks, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||||
|
} else {
|
||||||
|
update = grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, result.NetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||||
|
}
|
||||||
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
|
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
|
||||||
Update: update,
|
Update: update,
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
@@ -359,14 +397,14 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer b.mu.Unlock()
|
defer b.mu.Unlock()
|
||||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||||
if !b.update.Load() {
|
if !b.update.Load() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
b.update.Store(false)
|
b.update.Store(false)
|
||||||
if b.next == nil {
|
if b.next == nil {
|
||||||
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -376,6 +414,66 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetValidatedPeerWithComponents is the components-format counterpart of
|
||||||
|
// GetValidatedPeerWithMap. It returns raw NetworkMapComponents for capable
|
||||||
|
// peers along with the proxy NetworkMap fragment (BYOP / port-forwarding
|
||||||
|
// data the legacy server folds in via NetworkMap.Merge). The gRPC layer
|
||||||
|
// encodes both into the wire envelope. Callers must gate on capability
|
||||||
|
// themselves before dispatching here — this method does NOT branch on it.
|
||||||
|
func (c *Controller) GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||||
|
if isRequiresApproval {
|
||||||
|
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, 0, err
|
||||||
|
}
|
||||||
|
return peer, &types.NetworkMapComponents{Network: network.Copy()}, nil, nil, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
account.InjectProxyPolicies(ctx)
|
||||||
|
|
||||||
|
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch the proxy network map fragment for this peer alongside the
|
||||||
|
// components — same single-account-load path the streaming controller
|
||||||
|
// uses, so initial-sync delivers BYOP/forwarding patches synchronously
|
||||||
|
// instead of waiting for the next streaming push.
|
||||||
|
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||||
|
return nil, nil, nil, nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||||
|
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||||
|
|
||||||
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
|
routers := account.GetResourceRoutersMap()
|
||||||
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
components := account.GetPeerNetworkMapComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, groupIDToUserIDs)
|
||||||
|
|
||||||
|
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||||
|
|
||||||
|
return peer, components, proxyNetworkMaps[peer.ID], postureChecks, dnsFwdPort, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||||
if isRequiresApproval {
|
if isRequiresApproval {
|
||||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||||
|
|||||||
@@ -22,6 +22,10 @@ type Controller interface {
|
|||||||
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
||||||
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
||||||
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||||
|
GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||||
|
// PeerNeedsComponents combines the peer's advertised capability with the
|
||||||
|
// kill-switch flag — the only public predicate gRPC layers should ask.
|
||||||
|
PeerNeedsComponents(p *nbpeer.Peer) bool
|
||||||
GetDNSDomain(settings *types.Settings) string
|
GetDNSDomain(settings *types.Settings) string
|
||||||
StartWarmup(context.Context)
|
StartWarmup(context.Context)
|
||||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||||
|
|||||||
@@ -130,6 +130,39 @@ func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApp
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetValidatedPeerWithComponents mocks base method.
|
||||||
|
func (m *MockController) GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetValidatedPeerWithComponents", ctx, isRequiresApproval, accountID, p)
|
||||||
|
ret0, _ := ret[0].(*peer.Peer)
|
||||||
|
ret1, _ := ret[1].(*types.NetworkMapComponents)
|
||||||
|
ret2, _ := ret[2].(*types.NetworkMap)
|
||||||
|
ret3, _ := ret[3].([]*posture.Checks)
|
||||||
|
ret4, _ := ret[4].(int64)
|
||||||
|
ret5, _ := ret[5].(error)
|
||||||
|
return ret0, ret1, ret2, ret3, ret4, ret5
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetValidatedPeerWithComponents indicates an expected call of GetValidatedPeerWithComponents.
|
||||||
|
func (mr *MockControllerMockRecorder) GetValidatedPeerWithComponents(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithComponents", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithComponents), ctx, isRequiresApproval, accountID, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeerNeedsComponents mocks base method.
|
||||||
|
func (m *MockController) PeerNeedsComponents(p *peer.Peer) bool {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "PeerNeedsComponents", p)
|
||||||
|
ret0, _ := ret[0].(bool)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeerNeedsComponents indicates an expected call of PeerNeedsComponents.
|
||||||
|
func (mr *MockControllerMockRecorder) PeerNeedsComponents(p any) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeerNeedsComponents", reflect.TypeOf((*MockController)(nil).PeerNeedsComponents), p)
|
||||||
|
}
|
||||||
|
|
||||||
// OnPeerConnected mocks base method.
|
// OnPeerConnected mocks base method.
|
||||||
func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) {
|
func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
|||||||
found = true
|
found = true
|
||||||
select {
|
select {
|
||||||
case channel <- update:
|
case channel <- update:
|
||||||
log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID)
|
log.WithContext(ctx).Tracef("update was sent to channel for peer %s", peerID)
|
||||||
default:
|
default:
|
||||||
dropped = true
|
dropped = true
|
||||||
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))
|
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ package peers
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
@@ -35,6 +36,14 @@ type Manager interface {
|
|||||||
SetAccountManager(accountManager account.Manager)
|
SetAccountManager(accountManager account.Manager)
|
||||||
GetPeerID(ctx context.Context, peerKey string) (string, error)
|
GetPeerID(ctx context.Context, peerKey string) (string, error)
|
||||||
CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error
|
CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error
|
||||||
|
// GetPeerByTunnelIP looks up a peer in accountID by its WireGuard tunnel IP.
|
||||||
|
// Returns nil with an error when no match exists. No permission check;
|
||||||
|
// callers (the proxy's ValidateTunnelPeer RPC) are trusted server components.
|
||||||
|
GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error)
|
||||||
|
// GetPeerWithGroups returns the peer and the list of *types.Group it belongs
|
||||||
|
// to. Used by the proxy's auth path to authorise a request by the calling
|
||||||
|
// peer's group memberships.
|
||||||
|
GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type managerImpl struct {
|
type managerImpl struct {
|
||||||
@@ -99,6 +108,26 @@ func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string,
|
|||||||
return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
|
return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPeerByTunnelIP delegates to the store's indexed lookup.
|
||||||
|
func (m *managerImpl) GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) {
|
||||||
|
return m.store.GetPeerByIP(ctx, store.LockingStrengthNone, accountID, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerWithGroups returns the peer plus its group memberships. Any store
|
||||||
|
// error returns (nil, nil, err) so callers never receive a valid peer
|
||||||
|
// alongside a non-nil error.
|
||||||
|
func (m *managerImpl) GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) {
|
||||||
|
p, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
groups, err := m.store.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return p, groups, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ package peers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
context "context"
|
context "context"
|
||||||
|
net "net"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
@@ -13,6 +14,7 @@ import (
|
|||||||
account "github.com/netbirdio/netbird/management/server/account"
|
account "github.com/netbirdio/netbird/management/server/account"
|
||||||
integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||||
peer "github.com/netbirdio/netbird/management/server/peer"
|
peer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
types "github.com/netbirdio/netbird/management/server/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockManager is a mock of Manager interface.
|
// MockManager is a mock of Manager interface.
|
||||||
@@ -38,6 +40,20 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder {
|
|||||||
return m.recorder
|
return m.recorder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateProxyPeer mocks base method.
|
||||||
|
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID, peerKey, cluster string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
|
||||||
|
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
|
||||||
|
}
|
||||||
|
|
||||||
// DeletePeers mocks base method.
|
// DeletePeers mocks base method.
|
||||||
func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -97,6 +113,21 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPeerByTunnelIP mocks base method.
|
||||||
|
func (m *MockManager) GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetPeerByTunnelIP", ctx, accountID, ip)
|
||||||
|
ret0, _ := ret[0].(*peer.Peer)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerByTunnelIP indicates an expected call of GetPeerByTunnelIP.
|
||||||
|
func (mr *MockManagerMockRecorder) GetPeerByTunnelIP(ctx, accountID, ip interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerByTunnelIP", reflect.TypeOf((*MockManager)(nil).GetPeerByTunnelIP), ctx, accountID, ip)
|
||||||
|
}
|
||||||
|
|
||||||
// GetPeerID mocks base method.
|
// GetPeerID mocks base method.
|
||||||
func (m *MockManager) GetPeerID(ctx context.Context, peerKey string) (string, error) {
|
func (m *MockManager) GetPeerID(ctx context.Context, peerKey string) (string, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -112,6 +143,22 @@ func (mr *MockManagerMockRecorder) GetPeerID(ctx, peerKey interface{}) *gomock.C
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerID", reflect.TypeOf((*MockManager)(nil).GetPeerID), ctx, peerKey)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerID", reflect.TypeOf((*MockManager)(nil).GetPeerID), ctx, peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPeerWithGroups mocks base method.
|
||||||
|
func (m *MockManager) GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetPeerWithGroups", ctx, accountID, peerID)
|
||||||
|
ret0, _ := ret[0].(*peer.Peer)
|
||||||
|
ret1, _ := ret[1].([]*types.Group)
|
||||||
|
ret2, _ := ret[2].(error)
|
||||||
|
return ret0, ret1, ret2
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerWithGroups indicates an expected call of GetPeerWithGroups.
|
||||||
|
func (mr *MockManagerMockRecorder) GetPeerWithGroups(ctx, accountID, peerID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerWithGroups", reflect.TypeOf((*MockManager)(nil).GetPeerWithGroups), ctx, accountID, peerID)
|
||||||
|
}
|
||||||
|
|
||||||
// GetPeersByGroupIDs mocks base method.
|
// GetPeersByGroupIDs mocks base method.
|
||||||
func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
|
func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -162,17 +209,3 @@ func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController
|
|||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateProxyPeer mocks base method.
|
|
||||||
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
|
|
||||||
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ type Domain struct {
|
|||||||
// SupportsCrowdSec is populated at query time from proxy cluster capabilities.
|
// SupportsCrowdSec is populated at query time from proxy cluster capabilities.
|
||||||
// Not persisted.
|
// Not persisted.
|
||||||
SupportsCrowdSec *bool `gorm:"-"`
|
SupportsCrowdSec *bool `gorm:"-"`
|
||||||
|
// SupportsPrivate is populated at query time from proxy cluster capabilities. Not persisted.
|
||||||
|
SupportsPrivate *bool `gorm:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// EventMeta returns activity event metadata for a domain
|
// EventMeta returns activity event metadata for a domain
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
|
|||||||
SupportsCustomPorts: d.SupportsCustomPorts,
|
SupportsCustomPorts: d.SupportsCustomPorts,
|
||||||
RequireSubdomain: d.RequireSubdomain,
|
RequireSubdomain: d.RequireSubdomain,
|
||||||
SupportsCrowdsec: d.SupportsCrowdSec,
|
SupportsCrowdsec: d.SupportsCrowdSec,
|
||||||
|
SupportsPrivate: d.SupportsPrivate,
|
||||||
}
|
}
|
||||||
if d.TargetCluster != "" {
|
if d.TargetCluster != "" {
|
||||||
resp.TargetCluster = &d.TargetCluster
|
resp.TargetCluster = &d.TargetCluster
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ type proxyManager interface {
|
|||||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||||
|
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
@@ -93,6 +94,7 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
|||||||
d.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, cluster)
|
d.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, cluster)
|
||||||
d.RequireSubdomain = m.proxyManager.ClusterRequireSubdomain(ctx, cluster)
|
d.RequireSubdomain = m.proxyManager.ClusterRequireSubdomain(ctx, cluster)
|
||||||
d.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, cluster)
|
d.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, cluster)
|
||||||
|
d.SupportsPrivate = m.proxyManager.ClusterSupportsPrivate(ctx, cluster)
|
||||||
ret = append(ret, d)
|
ret = append(ret, d)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -109,6 +111,7 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
|||||||
if d.TargetCluster != "" {
|
if d.TargetCluster != "" {
|
||||||
cd.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, d.TargetCluster)
|
cd.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, d.TargetCluster)
|
||||||
cd.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, d.TargetCluster)
|
cd.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, d.TargetCluster)
|
||||||
|
cd.SupportsPrivate = m.proxyManager.ClusterSupportsPrivate(ctx, d.TargetCluster)
|
||||||
}
|
}
|
||||||
// Custom domains never require a subdomain by default since
|
// Custom domains never require a subdomain by default since
|
||||||
// the account owns them and should be able to use the bare domain.
|
// the account owns them and should be able to use the bare domain.
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type mockProxyManager struct {
|
type mockProxyManager struct {
|
||||||
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||||
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
|
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -40,6 +40,10 @@ func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockProxyManager) ClusterSupportsPrivate(_ context.Context, _ string) *bool {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) {
|
func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) {
|
||||||
pm := &mockProxyManager{
|
pm := &mockProxyManager{
|
||||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
|
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
|
||||||
@@ -151,4 +155,3 @@ func TestGetClusterAllowList_PublicEmpty_BYOPOnly(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []string{"byop.example.com"}, result)
|
assert.Equal(t, []string{"byop.example.com"}, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ type Manager interface {
|
|||||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||||
|
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||||
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
|
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
|
||||||
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
|
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
|
||||||
CountAccountProxies(ctx context.Context, accountID string) (int64, error)
|
CountAccountProxies(ctx context.Context, accountID string) (int64, error)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ type store interface {
|
|||||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||||
|
GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||||
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||||
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
||||||
@@ -137,6 +138,11 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string
|
|||||||
return m.store.GetClusterSupportsCrowdSec(ctx, clusterAddr)
|
return m.store.GetClusterSupportsCrowdSec(ctx, clusterAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClusterSupportsPrivate reports whether any active proxy claims the private capability (nil = unreported).
|
||||||
|
func (m Manager) ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
|
||||||
|
return m.store.GetClusterSupportsPrivate(ctx, clusterAddr)
|
||||||
|
}
|
||||||
|
|
||||||
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
||||||
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||||
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
||||||
@@ -178,4 +184,3 @@ func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, acco
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,16 +15,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type mockStore struct {
|
type mockStore struct {
|
||||||
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
|
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||||
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
|
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
|
||||||
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
|
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||||
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||||
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
|
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||||
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
|
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
|
||||||
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||||
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
|
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
|
||||||
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||||
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
|
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||||
@@ -99,6 +99,9 @@ func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *boo
|
|||||||
func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
|
func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (m *mockStore) GetClusterSupportsPrivate(_ context.Context, _ string) *bool {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func newTestManager(s store) *Manager {
|
func newTestManager(s store) *Manager {
|
||||||
meter := noop.NewMeterProvider().Meter("test")
|
meter := noop.NewMeterProvider().Meter("test")
|
||||||
|
|||||||
@@ -92,6 +92,20 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCrowdSec", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCrowdSec), ctx, clusterAddr)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCrowdSec", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCrowdSec), ctx, clusterAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClusterSupportsPrivate mocks base method.
|
||||||
|
func (m *MockManager) ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "ClusterSupportsPrivate", ctx, clusterAddr)
|
||||||
|
ret0, _ := ret[0].(*bool)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClusterSupportsPrivate indicates an expected call of ClusterSupportsPrivate.
|
||||||
|
func (mr *MockManagerMockRecorder) ClusterSupportsPrivate(ctx, clusterAddr interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsPrivate", reflect.TypeOf((*MockManager)(nil).ClusterSupportsPrivate), ctx, clusterAddr)
|
||||||
|
}
|
||||||
|
|
||||||
// Connect mocks base method.
|
// Connect mocks base method.
|
||||||
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) {
|
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|||||||
@@ -20,6 +20,9 @@ type Capabilities struct {
|
|||||||
RequireSubdomain *bool
|
RequireSubdomain *bool
|
||||||
// SupportsCrowdsec indicates whether this proxy has CrowdSec configured.
|
// SupportsCrowdsec indicates whether this proxy has CrowdSec configured.
|
||||||
SupportsCrowdsec *bool
|
SupportsCrowdsec *bool
|
||||||
|
// Private indicates whether this proxy supports inbound access via Wireguard
|
||||||
|
// tunnel and netbird-only authentication policies
|
||||||
|
Private *bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Proxy represents a reverse proxy instance
|
// Proxy represents a reverse proxy instance
|
||||||
@@ -67,10 +70,9 @@ type Cluster struct {
|
|||||||
Type ClusterType
|
Type ClusterType
|
||||||
Online bool
|
Online bool
|
||||||
ConnectedProxies int
|
ConnectedProxies int
|
||||||
// Capability flags. *bool because nil means "no proxy reported a
|
// *bool: nil = no proxy reported the capability; the dashboard renders that as unknown.
|
||||||
// capability for this cluster" — the dashboard renders these as
|
|
||||||
// unknown rather than false.
|
|
||||||
SupportsCustomPorts *bool
|
SupportsCustomPorts *bool
|
||||||
RequireSubdomain *bool
|
RequireSubdomain *bool
|
||||||
SupportsCrowdSec *bool
|
SupportsCrowdSec *bool
|
||||||
|
Private *bool
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -204,6 +204,7 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
|
|||||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||||
RequireSubdomain: c.RequireSubdomain,
|
RequireSubdomain: c.RequireSubdomain,
|
||||||
SupportsCrowdsec: c.SupportsCrowdSec,
|
SupportsCrowdsec: c.SupportsCrowdSec,
|
||||||
|
Private: c.Private,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ type CapabilityProvider interface {
|
|||||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||||
|
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
@@ -136,6 +137,7 @@ func (m *Manager) GetClusters(ctx context.Context, accountID, userID string) ([]
|
|||||||
clusters[i].SupportsCustomPorts = m.capabilities.ClusterSupportsCustomPorts(ctx, clusters[i].Address)
|
clusters[i].SupportsCustomPorts = m.capabilities.ClusterSupportsCustomPorts(ctx, clusters[i].Address)
|
||||||
clusters[i].RequireSubdomain = m.capabilities.ClusterRequireSubdomain(ctx, clusters[i].Address)
|
clusters[i].RequireSubdomain = m.capabilities.ClusterRequireSubdomain(ctx, clusters[i].Address)
|
||||||
clusters[i].SupportsCrowdSec = m.capabilities.ClusterSupportsCrowdSec(ctx, clusters[i].Address)
|
clusters[i].SupportsCrowdSec = m.capabilities.ClusterSupportsCrowdSec(ctx, clusters[i].Address)
|
||||||
|
clusters[i].Private = m.capabilities.ClusterSupportsPrivate(ctx, clusters[i].Address)
|
||||||
}
|
}
|
||||||
|
|
||||||
return clusters, nil
|
return clusters, nil
|
||||||
@@ -208,6 +210,9 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
|
|||||||
target.Host = resource.Domain
|
target.Host = resource.Domain
|
||||||
case service.TargetTypeSubnet:
|
case service.TargetTypeSubnet:
|
||||||
// For subnets we do not do any lookups on the resource
|
// For subnets we do not do any lookups on the resource
|
||||||
|
case service.TargetTypeCluster:
|
||||||
|
// Cluster targets carry the upstream address on target_id; the
|
||||||
|
// proxy resolves the destination at request time.
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unknown target type: %s", target.TargetType)
|
return fmt.Errorf("unknown target type: %s", target.TargetType)
|
||||||
}
|
}
|
||||||
@@ -779,6 +784,10 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
|
|||||||
if err := validateResourceTarget(ctx, transaction, accountID, target); err != nil {
|
if err := validateResourceTarget(ctx, transaction, accountID, target); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
case service.TargetTypeCluster:
|
||||||
|
if err := validateClusterTarget(target); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId)
|
return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId)
|
||||||
}
|
}
|
||||||
@@ -786,6 +795,13 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateClusterTarget(target *service.Target) error {
|
||||||
|
if !target.Options.DirectUpstream {
|
||||||
|
return status.Errorf(status.InvalidArgument, "cluster target %s has direct upstream disabled", target.Host)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func validatePeerTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error {
|
func validatePeerTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error {
|
||||||
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||||
@@ -962,12 +978,14 @@ func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID str
|
|||||||
return fmt.Errorf("failed to get services: %w", err)
|
return fmt.Errorf("failed to get services: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
oidcCfg := m.proxyController.GetOIDCValidationConfig()
|
||||||
|
|
||||||
for _, s := range services {
|
for _, s := range services {
|
||||||
err = m.replaceHostByLookup(ctx, accountID, s)
|
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||||
}
|
}
|
||||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -1344,3 +1344,66 @@ func TestValidateSubdomainRequirement(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateTargetReferences_ClusterTargetSkipsLookup(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
accountID := "test-account"
|
||||||
|
|
||||||
|
// No peer or resource lookups must be issued for cluster targets.
|
||||||
|
targets := []*rpservice.Target{
|
||||||
|
{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: rpservice.TargetTypeCluster,
|
||||||
|
Options: rpservice.TargetOptions{DirectUpstream: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.NoError(t, validateTargetReferences(ctx, mockStore, accountID, targets), "cluster target must validate without store lookups")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidateTargetReferences_ClusterTargetRequiresDirectUpstream pins the
|
||||||
|
// store-side check that cluster targets must opt into the host-stack dial
|
||||||
|
// path. Without DirectUpstream the proxy would route this target through
|
||||||
|
// the embedded NetBird client and fail on every request.
|
||||||
|
func TestValidateTargetReferences_ClusterTargetRequiresDirectUpstream(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
accountID := "test-account"
|
||||||
|
|
||||||
|
targets := []*rpservice.Target{
|
||||||
|
{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: rpservice.TargetTypeCluster,
|
||||||
|
Host: "backend.lan",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := validateTargetReferences(ctx, mockStore, accountID, targets)
|
||||||
|
require.Error(t, err, "cluster target without direct_upstream must be rejected")
|
||||||
|
assert.ErrorContains(t, err, "direct upstream disabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReplaceHostByLookup_SkipsClusterTarget(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
accountID := "test-account"
|
||||||
|
|
||||||
|
mgr := &Manager{store: mockStore}
|
||||||
|
|
||||||
|
svc := &rpservice.Service{
|
||||||
|
ID: "svc-1",
|
||||||
|
AccountID: accountID,
|
||||||
|
Targets: []*rpservice.Target{
|
||||||
|
{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: rpservice.TargetTypeCluster,
|
||||||
|
Host: "127.0.0.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, mgr.replaceHostByLookup(ctx, accountID, svc), "cluster target must not trigger peer/resource lookup")
|
||||||
|
assert.Equal(t, "127.0.0.1", svc.Targets[0].Host, "operator-supplied host must be preserved for cluster target")
|
||||||
|
}
|
||||||
|
|||||||
@@ -45,10 +45,11 @@ const (
|
|||||||
StatusCertificateFailed Status = "certificate_failed"
|
StatusCertificateFailed Status = "certificate_failed"
|
||||||
StatusError Status = "error"
|
StatusError Status = "error"
|
||||||
|
|
||||||
TargetTypePeer TargetType = "peer"
|
TargetTypePeer TargetType = "peer"
|
||||||
TargetTypeHost TargetType = "host"
|
TargetTypeHost TargetType = "host"
|
||||||
TargetTypeDomain TargetType = "domain"
|
TargetTypeDomain TargetType = "domain"
|
||||||
TargetTypeSubnet TargetType = "subnet"
|
TargetTypeSubnet TargetType = "subnet"
|
||||||
|
TargetTypeCluster TargetType = "cluster"
|
||||||
|
|
||||||
SourcePermanent = "permanent"
|
SourcePermanent = "permanent"
|
||||||
SourceEphemeral = "ephemeral"
|
SourceEphemeral = "ephemeral"
|
||||||
@@ -60,6 +61,11 @@ type TargetOptions struct {
|
|||||||
SessionIdleTimeout time.Duration `json:"session_idle_timeout,omitempty"`
|
SessionIdleTimeout time.Duration `json:"session_idle_timeout,omitempty"`
|
||||||
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
|
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
|
||||||
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
|
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
|
||||||
|
// DirectUpstream bypasses the proxy's embedded NetBird client and dials
|
||||||
|
// the target via the proxy host's network stack. Useful for upstreams
|
||||||
|
// reachable without WireGuard (public APIs, LAN services, localhost
|
||||||
|
// sidecars). Default false.
|
||||||
|
DirectUpstream bool `json:"direct_upstream,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Target struct {
|
type Target struct {
|
||||||
@@ -67,7 +73,7 @@ type Target struct {
|
|||||||
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
|
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
|
||||||
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
|
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
|
||||||
Path *string `json:"path,omitempty"`
|
Path *string `json:"path,omitempty"`
|
||||||
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
|
Host string `json:"host"`
|
||||||
Port uint16 `gorm:"index:idx_target_port" json:"port"`
|
Port uint16 `gorm:"index:idx_target_port" json:"port"`
|
||||||
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
|
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
|
||||||
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
|
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
|
||||||
@@ -200,6 +206,10 @@ type Service struct {
|
|||||||
Mode string `gorm:"default:'http'"`
|
Mode string `gorm:"default:'http'"`
|
||||||
ListenPort uint16
|
ListenPort uint16
|
||||||
PortAutoAssigned bool
|
PortAutoAssigned bool
|
||||||
|
// Private marks the service as NetBird-only: auth via ValidateTunnelPeer against AccessGroups instead of SSO. HTTP-only.
|
||||||
|
Private bool
|
||||||
|
// AccessGroups is the group ID allowlist for inbound peers on private services. Mutually exclusive with bearer SSO.
|
||||||
|
AccessGroups []string `json:"access_groups,omitempty" gorm:"serializer:json"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitNewRecord generates a new unique ID and resets metadata for a newly created
|
// InitNewRecord generates a new unique ID and resets metadata for a newly created
|
||||||
@@ -299,6 +309,12 @@ func (s *Service) ToAPIResponse() *api.Service {
|
|||||||
Mode: &mode,
|
Mode: &mode,
|
||||||
ListenPort: &listenPort,
|
ListenPort: &listenPort,
|
||||||
PortAutoAssigned: &s.PortAutoAssigned,
|
PortAutoAssigned: &s.PortAutoAssigned,
|
||||||
|
Private: &s.Private,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(s.AccessGroups) > 0 {
|
||||||
|
groups := append([]string(nil), s.AccessGroups...)
|
||||||
|
resp.AccessGroups = &groups
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.ProxyCluster != "" {
|
if s.ProxyCluster != "" {
|
||||||
@@ -308,6 +324,7 @@ func (s *Service) ToAPIResponse() *api.Service {
|
|||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ToProtoMapping converts the service into the wire format the proxy consumes.
|
||||||
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
|
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
|
||||||
pathMappings := s.buildPathMappings()
|
pathMappings := s.buildPathMappings()
|
||||||
|
|
||||||
@@ -349,6 +366,7 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
|
|||||||
RewriteRedirects: s.RewriteRedirects,
|
RewriteRedirects: s.RewriteRedirects,
|
||||||
Mode: s.Mode,
|
Mode: s.Mode,
|
||||||
ListenPort: int32(s.ListenPort), //nolint:gosec
|
ListenPort: int32(s.ListenPort), //nolint:gosec
|
||||||
|
Private: s.Private,
|
||||||
}
|
}
|
||||||
|
|
||||||
if r := restrictionsToProto(s.Restrictions); r != nil {
|
if r := restrictionsToProto(s.Restrictions); r != nil {
|
||||||
@@ -455,7 +473,8 @@ func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
|
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
|
||||||
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 {
|
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 &&
|
||||||
|
opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 && !opts.DirectUpstream {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
apiOpts := &api.ServiceTargetOptions{}
|
apiOpts := &api.ServiceTargetOptions{}
|
||||||
@@ -477,17 +496,22 @@ func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
|
|||||||
if len(opts.CustomHeaders) > 0 {
|
if len(opts.CustomHeaders) > 0 {
|
||||||
apiOpts.CustomHeaders = &opts.CustomHeaders
|
apiOpts.CustomHeaders = &opts.CustomHeaders
|
||||||
}
|
}
|
||||||
|
if opts.DirectUpstream {
|
||||||
|
apiOpts.DirectUpstream = &opts.DirectUpstream
|
||||||
|
}
|
||||||
return apiOpts
|
return apiOpts
|
||||||
}
|
}
|
||||||
|
|
||||||
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
|
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
|
||||||
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 && len(opts.CustomHeaders) == 0 {
|
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 &&
|
||||||
|
len(opts.CustomHeaders) == 0 && !opts.DirectUpstream {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
popts := &proto.PathTargetOptions{
|
popts := &proto.PathTargetOptions{
|
||||||
SkipTlsVerify: opts.SkipTLSVerify,
|
SkipTlsVerify: opts.SkipTLSVerify,
|
||||||
PathRewrite: pathRewriteToProto(opts.PathRewrite),
|
PathRewrite: pathRewriteToProto(opts.PathRewrite),
|
||||||
CustomHeaders: opts.CustomHeaders,
|
CustomHeaders: opts.CustomHeaders,
|
||||||
|
DirectUpstream: opts.DirectUpstream,
|
||||||
}
|
}
|
||||||
if opts.RequestTimeout != 0 {
|
if opts.RequestTimeout != 0 {
|
||||||
popts.RequestTimeout = durationpb.New(opts.RequestTimeout)
|
popts.RequestTimeout = durationpb.New(opts.RequestTimeout)
|
||||||
@@ -537,6 +561,9 @@ func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions,
|
|||||||
if o.CustomHeaders != nil {
|
if o.CustomHeaders != nil {
|
||||||
opts.CustomHeaders = *o.CustomHeaders
|
opts.CustomHeaders = *o.CustomHeaders
|
||||||
}
|
}
|
||||||
|
if o.DirectUpstream != nil {
|
||||||
|
opts.DirectUpstream = *o.DirectUpstream
|
||||||
|
}
|
||||||
return opts, nil
|
return opts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -551,6 +578,14 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro
|
|||||||
if req.ListenPort != nil {
|
if req.ListenPort != nil {
|
||||||
s.ListenPort = uint16(*req.ListenPort) //nolint:gosec
|
s.ListenPort = uint16(*req.ListenPort) //nolint:gosec
|
||||||
}
|
}
|
||||||
|
if req.Private != nil {
|
||||||
|
s.Private = *req.Private
|
||||||
|
}
|
||||||
|
if req.AccessGroups != nil {
|
||||||
|
s.AccessGroups = append([]string(nil), *req.AccessGroups...)
|
||||||
|
} else {
|
||||||
|
s.AccessGroups = nil
|
||||||
|
}
|
||||||
|
|
||||||
targets, err := targetsFromAPI(accountID, req.Targets)
|
targets, err := targetsFromAPI(accountID, req.Targets)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -740,6 +775,9 @@ func (s *Service) Validate() error {
|
|||||||
if err := validateAccessRestrictions(&s.Restrictions); err != nil {
|
if err := validateAccessRestrictions(&s.Restrictions); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if err := s.validatePrivateRequirements(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
switch s.Mode {
|
switch s.Mode {
|
||||||
case ModeHTTP:
|
case ModeHTTP:
|
||||||
@@ -753,6 +791,23 @@ func (s *Service) Validate() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validatePrivateRequirements enforces the private-service contract: HTTP mode, ≥1 access group, no bearer auth.
|
||||||
|
func (s *Service) validatePrivateRequirements() error {
|
||||||
|
if !s.Private {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if s.Mode != "" && s.Mode != ModeHTTP {
|
||||||
|
return fmt.Errorf("private services only support HTTP mode, got %q", s.Mode)
|
||||||
|
}
|
||||||
|
if len(s.AccessGroups) == 0 {
|
||||||
|
return errors.New("private services require at least one access group")
|
||||||
|
}
|
||||||
|
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
|
||||||
|
return errors.New("private services cannot enable bearer auth (SSO): NetBird-only access and SSO are mutually exclusive")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Service) validateHTTPMode() error {
|
func (s *Service) validateHTTPMode() error {
|
||||||
if s.Domain == "" {
|
if s.Domain == "" {
|
||||||
return errors.New("service domain is required")
|
return errors.New("service domain is required")
|
||||||
@@ -799,11 +854,21 @@ func (s *Service) validateHTTPTargets() error {
|
|||||||
for i, target := range s.Targets {
|
for i, target := range s.Targets {
|
||||||
switch target.TargetType {
|
switch target.TargetType {
|
||||||
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
||||||
// host field will be ignored
|
// Host is normally overwritten by replaceHostByLookup with the
|
||||||
|
// resolved peer IP / resource address; operator-supplied values
|
||||||
|
// are honored only when DirectUpstream is set. Validate the
|
||||||
|
// override here so misconfigured hosts fail fast at API time.
|
||||||
|
if err := validateDirectUpstreamHost(i, target); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
case TargetTypeSubnet:
|
case TargetTypeSubnet:
|
||||||
if target.Host == "" {
|
if target.Host == "" {
|
||||||
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
|
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
|
||||||
}
|
}
|
||||||
|
case TargetTypeCluster:
|
||||||
|
if err := validateClusterTarget(i, target); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
|
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
|
||||||
}
|
}
|
||||||
@@ -821,25 +886,67 @@ func (s *Service) validateHTTPTargets() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateClusterTarget cluster targets should not have empty hosts and should have direct upstream enabled.
|
||||||
|
func validateClusterTarget(idx int, target *Target) error {
|
||||||
|
host := strings.TrimSpace(target.Host)
|
||||||
|
if host == "" {
|
||||||
|
return fmt.Errorf("target %d: has empty host", idx)
|
||||||
|
}
|
||||||
|
if !target.Options.DirectUpstream {
|
||||||
|
return fmt.Errorf("target %d: %s has direct upstream disabled", idx, target.Host)
|
||||||
|
}
|
||||||
|
return validateDirectUpstreamHost(idx, target)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateDirectUpstreamHost validates the operator-supplied Host on a
|
||||||
|
// peer/host/domain target when DirectUpstream is set. Empty Host is
|
||||||
|
// allowed — the lookup fills in the default peer IP / resource address.
|
||||||
|
// Without DirectUpstream the Host value is silently overwritten by
|
||||||
|
// replaceHostByLookup, so we don't validate it (preserves the historical
|
||||||
|
// behaviour where APIs accepted any value and dropped it). Non-empty
|
||||||
|
// Host with DirectUpstream must look like a hostname or IP and must
|
||||||
|
// not carry a port (port lives on Target.Port).
|
||||||
|
func validateDirectUpstreamHost(idx int, target *Target) error {
|
||||||
|
if !target.Options.DirectUpstream {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
host := strings.TrimSpace(target.Host)
|
||||||
|
if host == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if strings.ContainsAny(host, " \t/") {
|
||||||
|
return fmt.Errorf("target %d: host %q contains invalid characters", idx, host)
|
||||||
|
}
|
||||||
|
if _, _, err := net.SplitHostPort(host); err == nil {
|
||||||
|
return fmt.Errorf("target %d: host %q must not include a port (set target.port instead)", idx, host)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Service) validateL4Target(target *Target) error {
|
func (s *Service) validateL4Target(target *Target) error {
|
||||||
// L4 services have a single target; per-target disable is meaningless
|
// L4 services have a single target; per-target disable is meaningless
|
||||||
// (use the service-level Enabled flag instead). Force it on so that
|
// (use the service-level Enabled flag instead). Force it on so that
|
||||||
// buildPathMappings always includes the target in the proto.
|
// buildPathMappings always includes the target in the proto.
|
||||||
target.Enabled = true
|
target.Enabled = true
|
||||||
|
|
||||||
if target.Port == 0 {
|
|
||||||
return errors.New("target port is required for L4 services")
|
|
||||||
}
|
|
||||||
if target.TargetId == "" {
|
if target.TargetId == "" {
|
||||||
return errors.New("target_id is required for L4 services")
|
return errors.New("target_id is required for L4 services")
|
||||||
}
|
}
|
||||||
|
if target.TargetType != TargetTypeCluster && target.Port == 0 {
|
||||||
|
return errors.New("target port is required for L4 services")
|
||||||
|
}
|
||||||
switch target.TargetType {
|
switch target.TargetType {
|
||||||
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
||||||
// OK
|
if err := validateDirectUpstreamHost(0, target); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
case TargetTypeSubnet:
|
case TargetTypeSubnet:
|
||||||
if target.Host == "" {
|
if target.Host == "" {
|
||||||
return errors.New("target host is required for subnet targets")
|
return errors.New("target host is required for subnet targets")
|
||||||
}
|
}
|
||||||
|
case TargetTypeCluster:
|
||||||
|
// target_id carries the cluster address; the proxy resolves
|
||||||
|
// the upstream at request time.
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("invalid target_type %q for L4 service", target.TargetType)
|
return fmt.Errorf("invalid target_type %q for L4 service", target.TargetType)
|
||||||
}
|
}
|
||||||
@@ -1174,6 +1281,11 @@ func (s *Service) Copy() *Service {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var accessGroups []string
|
||||||
|
if len(s.AccessGroups) > 0 {
|
||||||
|
accessGroups = append([]string(nil), s.AccessGroups...)
|
||||||
|
}
|
||||||
|
|
||||||
return &Service{
|
return &Service{
|
||||||
ID: s.ID,
|
ID: s.ID,
|
||||||
AccountID: s.AccountID,
|
AccountID: s.AccountID,
|
||||||
@@ -1195,6 +1307,8 @@ func (s *Service) Copy() *Service {
|
|||||||
Mode: s.Mode,
|
Mode: s.Mode,
|
||||||
ListenPort: s.ListenPort,
|
ListenPort: s.ListenPort,
|
||||||
PortAutoAssigned: s.PortAutoAssigned,
|
PortAutoAssigned: s.PortAutoAssigned,
|
||||||
|
Private: s.Private,
|
||||||
|
AccessGroups: accessGroups,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||||
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1116,3 +1117,191 @@ func TestValidate_HeaderAuths(t *testing.T) {
|
|||||||
assert.Contains(t, err.Error(), "exceeds maximum length")
|
assert.Contains(t, err.Error(), "exceeds maximum length")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidate_HTTPClusterTarget(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets = []*Target{{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: TargetTypeCluster,
|
||||||
|
Protocol: "http",
|
||||||
|
Host: "backend.lan",
|
||||||
|
Options: TargetOptions{DirectUpstream: true},
|
||||||
|
Enabled: true,
|
||||||
|
}}
|
||||||
|
require.NoError(t, rp.Validate(), "HTTP cluster target with target_id, host, and direct_upstream must validate")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_HTTPClusterTarget_RequiresTargetId(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets = []*Target{{
|
||||||
|
TargetType: TargetTypeCluster,
|
||||||
|
Protocol: "http",
|
||||||
|
Host: "backend.lan",
|
||||||
|
Options: TargetOptions{DirectUpstream: true},
|
||||||
|
Enabled: true,
|
||||||
|
}}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "empty target_id", "cluster target must reject empty target_id")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidate_HTTPClusterTarget_RequiresHost pins the new cluster-target
|
||||||
|
// rule that operator-supplied Host is mandatory: cluster targets dial the
|
||||||
|
// upstream via the host network stack (direct_upstream is implied), so an
|
||||||
|
// empty Host leaves the proxy with nothing to dial.
|
||||||
|
func TestValidate_HTTPClusterTarget_RequiresHost(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets = []*Target{{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: TargetTypeCluster,
|
||||||
|
Protocol: "http",
|
||||||
|
Options: TargetOptions{DirectUpstream: true},
|
||||||
|
Enabled: true,
|
||||||
|
}}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "empty host", "cluster target must reject empty host")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidate_HTTPClusterTarget_RequiresDirectUpstream pins the second
|
||||||
|
// half of the cluster-target rule: DirectUpstream must be true so the
|
||||||
|
// stdlib transport branch in MultiTransport is taken. Without it the
|
||||||
|
// embedded NetBird client would try to dial the cluster address through
|
||||||
|
// the WG tunnel, which is the wrong network for a cluster upstream.
|
||||||
|
func TestValidate_HTTPClusterTarget_RequiresDirectUpstream(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets = []*Target{{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: TargetTypeCluster,
|
||||||
|
Protocol: "http",
|
||||||
|
Host: "backend.lan",
|
||||||
|
Enabled: true,
|
||||||
|
}}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "direct upstream disabled", "cluster target must reject direct_upstream=false")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_L4ClusterTarget(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Mode = ModeTCP
|
||||||
|
rp.ListenPort = 9000
|
||||||
|
rp.Targets = []*Target{{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: TargetTypeCluster,
|
||||||
|
Protocol: "tcp",
|
||||||
|
Enabled: true,
|
||||||
|
}}
|
||||||
|
require.NoError(t, rp.Validate(), "L4 cluster target must validate without an explicit port")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestService_Copy_RoundtripsPrivate(t *testing.T) {
|
||||||
|
svc := validProxy()
|
||||||
|
svc.Private = true
|
||||||
|
svc.AccessGroups = []string{"grp-admins", "grp-ops"}
|
||||||
|
cp := svc.Copy()
|
||||||
|
require.NotNil(t, cp)
|
||||||
|
assert.True(t, cp.Private)
|
||||||
|
assert.Equal(t, []string{"grp-admins", "grp-ops"}, cp.AccessGroups)
|
||||||
|
|
||||||
|
cp.Private = false
|
||||||
|
assert.True(t, svc.Private)
|
||||||
|
|
||||||
|
cp.AccessGroups[0] = "grp-other"
|
||||||
|
assert.Equal(t, []string{"grp-admins", "grp-ops"}, svc.AccessGroups)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestService_APIRoundtrip_Private(t *testing.T) {
|
||||||
|
enabled := true
|
||||||
|
private := true
|
||||||
|
accessGroups := []string{"grp-admins"}
|
||||||
|
targets := []api.ServiceTarget{{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: api.ServiceTargetTargetType("cluster"),
|
||||||
|
Protocol: "http",
|
||||||
|
Port: 80,
|
||||||
|
Enabled: true,
|
||||||
|
}}
|
||||||
|
req := &api.ServiceRequest{
|
||||||
|
Name: "svc-private",
|
||||||
|
Domain: "myapp.eu.proxy.netbird.io",
|
||||||
|
Enabled: enabled,
|
||||||
|
Private: &private,
|
||||||
|
AccessGroups: &accessGroups,
|
||||||
|
Targets: &targets,
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &Service{}
|
||||||
|
require.NoError(t, svc.FromAPIRequest(req, "acc-1"))
|
||||||
|
assert.True(t, svc.Private)
|
||||||
|
assert.Equal(t, []string{"grp-admins"}, svc.AccessGroups)
|
||||||
|
|
||||||
|
resp := svc.ToAPIResponse()
|
||||||
|
require.NotNil(t, resp.Private)
|
||||||
|
assert.True(t, *resp.Private)
|
||||||
|
require.NotNil(t, resp.AccessGroups)
|
||||||
|
assert.Equal(t, []string{"grp-admins"}, *resp.AccessGroups)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_Private_RequiresAccessGroups(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Private = true
|
||||||
|
rp.Targets = []*Target{{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: TargetTypeCluster,
|
||||||
|
Protocol: "http",
|
||||||
|
Host: "backend.lan",
|
||||||
|
Options: TargetOptions{DirectUpstream: true},
|
||||||
|
Enabled: true,
|
||||||
|
}}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "access group")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_Private_RejectsBearerAuth(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Private = true
|
||||||
|
rp.AccessGroups = []string{"grp-admins"}
|
||||||
|
rp.Auth.BearerAuth = &BearerAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
DistributionGroups: []string{"grp-sso"},
|
||||||
|
}
|
||||||
|
rp.Targets = []*Target{{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: TargetTypeCluster,
|
||||||
|
Protocol: "http",
|
||||||
|
Host: "backend.lan",
|
||||||
|
Options: TargetOptions{DirectUpstream: true},
|
||||||
|
Enabled: true,
|
||||||
|
}}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "mutually exclusive")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_Private_AcceptsNonClusterTargets(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Private = true
|
||||||
|
rp.AccessGroups = []string{"grp-admins"}
|
||||||
|
require.NoError(t, rp.Validate())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_Private_AcceptsClusterTargetWithAccessGroups(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Private = true
|
||||||
|
rp.AccessGroups = []string{"grp-admins"}
|
||||||
|
rp.Targets = []*Target{{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: TargetTypeCluster,
|
||||||
|
Protocol: "http",
|
||||||
|
Host: "backend.lan",
|
||||||
|
Options: TargetOptions{DirectUpstream: true},
|
||||||
|
Enabled: true,
|
||||||
|
}}
|
||||||
|
require.NoError(t, rp.Validate())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_Private_RejectsNonHTTPMode(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Private = true
|
||||||
|
rp.AccessGroups = []string{"grp-admins"}
|
||||||
|
rp.Mode = ModeTCP
|
||||||
|
rp.Targets = []*Target{{
|
||||||
|
TargetId: "eu.proxy.netbird.io",
|
||||||
|
TargetType: TargetTypeCluster,
|
||||||
|
Protocol: "tcp",
|
||||||
|
Enabled: true,
|
||||||
|
}}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "HTTP")
|
||||||
|
}
|
||||||
|
|||||||
@@ -20,6 +20,20 @@ type KeyPair struct {
|
|||||||
type Claims struct {
|
type Claims struct {
|
||||||
jwt.RegisteredClaims
|
jwt.RegisteredClaims
|
||||||
Method auth.Method `json:"method"`
|
Method auth.Method `json:"method"`
|
||||||
|
// Email is the calling user's email address. Carried so the
|
||||||
|
// proxy can stamp identity on upstream requests (e.g.
|
||||||
|
// x-litellm-end-user-id) without an extra management
|
||||||
|
// round-trip on every cookie-bearing request.
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
// Groups carries the user's group IDs so the proxy can stamp them
|
||||||
|
// onto upstream requests (X-NetBird-Groups) from the cookie path
|
||||||
|
// without an extra management round-trip.
|
||||||
|
Groups []string `json:"groups,omitempty"`
|
||||||
|
// GroupNames carries the human-readable display names for the ids
|
||||||
|
// in Groups, ordered identically (positional pairing). Slice may be
|
||||||
|
// shorter than Groups for tokens minted before names were
|
||||||
|
// resolvable; the consumer falls back to ids for missing positions.
|
||||||
|
GroupNames []string `json:"group_names,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateKeyPair() (*KeyPair, error) {
|
func GenerateKeyPair() (*KeyPair, error) {
|
||||||
@@ -34,7 +48,13 @@ func GenerateKeyPair() (*KeyPair, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration time.Duration) (string, error) {
|
// SignToken mints a session JWT for the given user and domain. email,
|
||||||
|
// groups, and groupNames, when non-empty, are embedded so the proxy can
|
||||||
|
// authorise and stamp identity for policy-aware middlewares without a
|
||||||
|
// management round-trip on every cookie-bearing request. groupNames
|
||||||
|
// pairs positionally with groups; pass nil when names couldn't be
|
||||||
|
// resolved.
|
||||||
|
func SignToken(privKeyB64, userID, email, domain string, method auth.Method, groups, groupNames []string, expiration time.Duration) (string, error) {
|
||||||
privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyB64)
|
privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyB64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("decode private key: %w", err)
|
return "", fmt.Errorf("decode private key: %w", err)
|
||||||
@@ -56,7 +76,10 @@ func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration
|
|||||||
IssuedAt: jwt.NewNumericDate(now),
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
NotBefore: jwt.NewNumericDate(now),
|
NotBefore: jwt.NewNumericDate(now),
|
||||||
},
|
},
|
||||||
Method: method,
|
Method: method,
|
||||||
|
Email: email,
|
||||||
|
Groups: append([]string(nil), groups...),
|
||||||
|
GroupNames: append([]string(nil), groupNames...),
|
||||||
}
|
}
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
|
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
|
||||||
|
|||||||
@@ -10,8 +10,10 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
|
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
|
||||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||||
|
"github.com/rs/cors"
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
@@ -19,7 +21,6 @@ import (
|
|||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
cachestore "github.com/eko/gocache/lib/v4/store"
|
cachestore "github.com/eko/gocache/lib/v4/store"
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
"github.com/netbirdio/netbird/formatter/hook"
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
@@ -27,16 +28,20 @@ import (
|
|||||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
activitystore "github.com/netbirdio/netbird/management/server/activity/store"
|
||||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/util/crypt"
|
"github.com/netbirdio/netbird/util/crypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const apiPrefix = "/api"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
kaep = keepalive.EnforcementPolicy{
|
kaep = keepalive.EnforcementPolicy{
|
||||||
MinTime: 15 * time.Second,
|
MinTime: 15 * time.Second,
|
||||||
@@ -94,12 +99,17 @@ func (s *BaseServer) Store() store.Store {
|
|||||||
|
|
||||||
func (s *BaseServer) EventStore() activity.Store {
|
func (s *BaseServer) EventStore() activity.Store {
|
||||||
return Create(s, func() activity.Store {
|
return Create(s, func() activity.Store {
|
||||||
integrationMetrics, err := integrations.InitIntegrationMetrics(context.Background(), s.Metrics())
|
var err error
|
||||||
if err != nil {
|
key := s.Config.DataStoreEncryptionKey
|
||||||
log.Fatalf("failed to initialize integration metrics: %v", err)
|
if key == "" {
|
||||||
|
log.Debugf("generate new activity store encryption key")
|
||||||
|
key, err = crypt.GenerateKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to generate event store encryption key: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
eventStore, _, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics)
|
eventStore, err := activitystore.NewSqlStore(context.Background(), s.Config.Datadir, key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to initialize event store: %v", err)
|
log.Fatalf("failed to initialize event store: %v", err)
|
||||||
}
|
}
|
||||||
@@ -110,7 +120,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
|||||||
|
|
||||||
func (s *BaseServer) APIHandler() http.Handler {
|
func (s *BaseServer) APIHandler() http.Handler {
|
||||||
return Create(s, func() http.Handler {
|
return Create(s, func() http.Handler {
|
||||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter())
|
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create API handler: %v", err)
|
log.Fatalf("failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
@@ -118,6 +128,22 @@ func (s *BaseServer) APIHandler() http.Handler {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IDPHandler returns the HTTP handler for the embedded IdP (Dex), or nil if
|
||||||
|
// the deployment isn't using the embedded variant.
|
||||||
|
func (s *BaseServer) IDPHandler() http.Handler {
|
||||||
|
embeddedIdP, ok := s.IdpManager().(*idp.EmbeddedIdPManager)
|
||||||
|
if !ok || embeddedIdP == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return cors.AllowAll().Handler(embeddedIdP.Handler())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) Router() *mux.Router {
|
||||||
|
return Create(s, func() *mux.Router {
|
||||||
|
return mux.NewRouter().PathPrefix(apiPrefix).Subrouter()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
|
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
|
||||||
return Create(s, func() *middleware.APIRateLimiter {
|
return Create(s, func() *middleware.APIRateLimiter {
|
||||||
cfg, enabled := middleware.RateLimiterConfigFromEnv()
|
cfg, enabled := middleware.RateLimiterConfigFromEnv()
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/auth"
|
"github.com/netbirdio/netbird/management/server/auth"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/job"
|
"github.com/netbirdio/netbird/management/server/job"
|
||||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||||
@@ -38,7 +39,7 @@ func (s *BaseServer) JobManager() *job.Manager {
|
|||||||
|
|
||||||
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
|
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
|
||||||
return Create(s, func() integrated_validator.IntegratedValidator {
|
return Create(s, func() integrated_validator.IntegratedValidator {
|
||||||
integratedPeerValidator, err := integrations.NewIntegratedValidator(
|
integratedPeerValidator, err := validator.NewIntegratedValidator(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
s.PeersManager(),
|
s.PeersManager(),
|
||||||
s.SettingsManager(),
|
s.SettingsManager(),
|
||||||
|
|||||||
@@ -57,13 +57,7 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
|
|||||||
|
|
||||||
func (s *BaseServer) PermissionsManager() permissions.Manager {
|
func (s *BaseServer) PermissionsManager() permissions.Manager {
|
||||||
return Create(s, func() permissions.Manager {
|
return Create(s, func() permissions.Manager {
|
||||||
manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter())
|
return permissions.NewManager(s.Store())
|
||||||
|
|
||||||
s.AfterInit(func(s *BaseServer) {
|
|
||||||
manager.SetAccountManager(s.AccountManager())
|
|
||||||
})
|
|
||||||
|
|
||||||
return manager
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -153,7 +147,6 @@ func (s *BaseServer) IdpManager() idp.Manager {
|
|||||||
return idpManager
|
return idpManager
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -235,3 +228,7 @@ func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
|||||||
return &m
|
return &m
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) IsValidChildAccount(_ context.Context, _, _, _ string) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -188,7 +188,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
|||||||
log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
|
log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter())
|
rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.IDPHandler(), s.Metrics().GetMeter())
|
||||||
switch {
|
switch {
|
||||||
case s.certManager != nil:
|
case s.certManager != nil:
|
||||||
// a call to certManager.Listener() always creates a new listener so we do it once
|
// a call to certManager.Listener() always creates a new listener so we do it once
|
||||||
@@ -299,7 +299,7 @@ func (s *BaseServer) SetHandlerFunc(handler http.Handler) {
|
|||||||
log.Tracef("custom handler set successfully")
|
log.Tracef("custom handler set successfully")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
|
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, meter metric.Meter) http.Handler {
|
||||||
// Check if a custom handler was set (for multiplexing additional services)
|
// Check if a custom handler was set (for multiplexing additional services)
|
||||||
if customHandler, ok := s.GetContainer("customHandler"); ok {
|
if customHandler, ok := s.GetContainer("customHandler"); ok {
|
||||||
if handler, ok := customHandler.(http.Handler); ok {
|
if handler, ok := customHandler.(http.Handler); ok {
|
||||||
@@ -318,6 +318,8 @@ func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, ht
|
|||||||
gRPCHandler.ServeHTTP(writer, request)
|
gRPCHandler.ServeHTTP(writer, request)
|
||||||
case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
|
case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
|
||||||
wsProxy.Handler().ServeHTTP(writer, request)
|
wsProxy.Handler().ServeHTTP(writer, request)
|
||||||
|
case idpHandler != nil && strings.HasPrefix(request.URL.Path, "/oauth2"):
|
||||||
|
idpHandler.ServeHTTP(writer, request)
|
||||||
default:
|
default:
|
||||||
httpHandler.ServeHTTP(writer, request)
|
httpHandler.ServeHTTP(writer, request)
|
||||||
}
|
}
|
||||||
|
|||||||
813
management/internals/shared/grpc/components_encoder.go
Normal file
813
management/internals/shared/grpc/components_encoder.go
Normal file
@@ -0,0 +1,813 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
nbroute "github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// wgKeyRawLen is the raw byte length of a WireGuard public key.
|
||||||
|
const wgKeyRawLen = 32
|
||||||
|
|
||||||
|
// ComponentsEnvelopeInput bundles the data the component-format encoder needs.
|
||||||
|
// The envelope is fully self-contained — every field needed by the client's
|
||||||
|
// local Calculate() comes from the components struct itself. The only
|
||||||
|
// externally-supplied data is the receiving peer's PeerConfig (which is
|
||||||
|
// computed alongside the components in the network_map controller and reused
|
||||||
|
// from the legacy proto path) and the dns_domain string.
|
||||||
|
type ComponentsEnvelopeInput struct {
|
||||||
|
Components *types.NetworkMapComponents
|
||||||
|
PeerConfig *proto.PeerConfig
|
||||||
|
DNSDomain string
|
||||||
|
DNSForwarderPort int64
|
||||||
|
// UserIDClaim is the OIDC claim name the client should embed in
|
||||||
|
// SshAuth.UserIDClaim when reconstructing the NetworkMap. Empty value
|
||||||
|
// is OK — client treats empty as "no SshAuth to build".
|
||||||
|
UserIDClaim string
|
||||||
|
// ProxyPatch carries pre-expanded NetworkMap fragments injected by
|
||||||
|
// external controllers (BYOP/port-forwarding). Nil when no proxy data
|
||||||
|
// is present; encoder skips the field in that case.
|
||||||
|
ProxyPatch *proto.ProxyPatch
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeNetworkMapEnvelope converts NetworkMapComponents into the component
|
||||||
|
// wire envelope. The encoder is intentionally non-deterministic: it iterates
|
||||||
|
// Go maps in their native (random) order. Indexes inside the envelope
|
||||||
|
// (peer_indexes, source_group_ids, agent_version_idx, router_peer_indexes)
|
||||||
|
// are self-consistent within a single encode, so the decoder reconstructs
|
||||||
|
// the same typed objects regardless of emit order. Tests that need to
|
||||||
|
// compare envelopes do so semantically via proto round-trip + canonicalize,
|
||||||
|
// not byte-equal.
|
||||||
|
//
|
||||||
|
// Callers must NOT concatenate or merge envelopes from different encodes —
|
||||||
|
// index spaces are local to a single envelope.
|
||||||
|
func EncodeNetworkMapEnvelope(in ComponentsEnvelopeInput) *proto.NetworkMapEnvelope {
|
||||||
|
c := in.Components
|
||||||
|
|
||||||
|
// Graceful degrade when components is nil — matches the legacy path's
|
||||||
|
// behaviour for missing/unvalidated peers (return a NetworkMap with only
|
||||||
|
// Network populated). The receiver gets an envelope it can decode
|
||||||
|
// without crashing; AccountSettings stays non-nil so client-side
|
||||||
|
// dereferences are safe.
|
||||||
|
if c == nil {
|
||||||
|
// Match legacy missing-peer minimum: a NetworkMap with only Network
|
||||||
|
// populated. The receiver gets enough to bootstrap (Network
|
||||||
|
// identifier, dns_domain, account_settings) and nothing else.
|
||||||
|
return &proto.NetworkMapEnvelope{
|
||||||
|
Payload: &proto.NetworkMapEnvelope_Full{
|
||||||
|
Full: &proto.NetworkMapComponentsFull{
|
||||||
|
PeerConfig: in.PeerConfig,
|
||||||
|
DnsDomain: in.DNSDomain,
|
||||||
|
DnsForwarderPort: in.DNSForwarderPort,
|
||||||
|
UserIdClaim: in.UserIDClaim,
|
||||||
|
AccountSettings: &proto.AccountSettingsCompact{},
|
||||||
|
ProxyPatch: in.ProxyPatch,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 1: build dedup tables. Every routing peer (in c.RouterPeers) and
|
||||||
|
// every regular peer (in c.Peers) must be indexed before any encoder
|
||||||
|
// looks up indexes via e.peerOrder — otherwise routes / routers_map for
|
||||||
|
// peers that exist only in c.RouterPeers would silently lose their
|
||||||
|
// peer_index reference.
|
||||||
|
enc := newComponentEncoder(c)
|
||||||
|
enc.indexAllPeers()
|
||||||
|
routerIdxs := enc.indexRouterPeers(c.RouterPeers)
|
||||||
|
|
||||||
|
// Phase 2: gather every policy that any consumer references (peer-pair
|
||||||
|
// policies + resource-only policies) so encodeResourcePoliciesMap can
|
||||||
|
// translate every *Policy pointer to a wire index.
|
||||||
|
allPolicies := unionPolicies(c.Policies, c.ResourcePoliciesMap)
|
||||||
|
policies, policyToIdxs := enc.encodePolicies(allPolicies)
|
||||||
|
|
||||||
|
// Phase 3: emit. Order of struct field expressions no longer matters:
|
||||||
|
// every encoder either reads from the dedup tables or works on
|
||||||
|
// independent input.
|
||||||
|
full := &proto.NetworkMapComponentsFull{
|
||||||
|
Serial: networkSerial(c.Network),
|
||||||
|
PeerConfig: in.PeerConfig,
|
||||||
|
Network: toAccountNetwork(c.Network),
|
||||||
|
AccountSettings: toAccountSettingsCompact(c.AccountSettings),
|
||||||
|
DnsForwarderPort: in.DNSForwarderPort,
|
||||||
|
UserIdClaim: in.UserIDClaim,
|
||||||
|
ProxyPatch: in.ProxyPatch,
|
||||||
|
DnsSettings: enc.encodeDNSSettings(c.DNSSettings),
|
||||||
|
DnsDomain: in.DNSDomain,
|
||||||
|
CustomZoneDomain: c.CustomZoneDomain,
|
||||||
|
AgentVersions: enc.agentVersions,
|
||||||
|
Peers: enc.peers,
|
||||||
|
RouterPeerIndexes: routerIdxs,
|
||||||
|
Policies: policies,
|
||||||
|
Groups: enc.encodeGroups(),
|
||||||
|
Routes: enc.encodeRoutes(c.Routes),
|
||||||
|
NameserverGroups: enc.encodeNameServerGroups(c.NameServerGroups),
|
||||||
|
AllDnsRecords: encodeSimpleRecords(c.AllDNSRecords),
|
||||||
|
AccountZones: encodeCustomZones(c.AccountZones),
|
||||||
|
NetworkResources: enc.encodeNetworkResources(c.NetworkResources),
|
||||||
|
RoutersMap: enc.encodeRoutersMap(c.RoutersMap),
|
||||||
|
ResourcePoliciesMap: enc.encodeResourcePoliciesMap(c.ResourcePoliciesMap, policyToIdxs),
|
||||||
|
GroupIdToUserIds: enc.encodeGroupIDToUserIDs(c.GroupIDToUserIDs),
|
||||||
|
AllowedUserIds: stringSetToSlice(c.AllowedUserIDs),
|
||||||
|
PostureFailedPeers: enc.encodePostureFailedPeers(c.PostureFailedPeers),
|
||||||
|
}
|
||||||
|
|
||||||
|
return &proto.NetworkMapEnvelope{
|
||||||
|
Payload: &proto.NetworkMapEnvelope_Full{Full: full},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// networkSerial returns c.Network.CurrentSerial() with a nil guard. The
|
||||||
|
// production path always populates c.Network, but the encoder is exported
|
||||||
|
// and a hand-built components struct may omit it.
|
||||||
|
func networkSerial(n *types.Network) uint64 {
|
||||||
|
if n == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return n.CurrentSerial()
|
||||||
|
}
|
||||||
|
|
||||||
|
type componentEncoder struct {
|
||||||
|
components *types.NetworkMapComponents
|
||||||
|
|
||||||
|
peerOrder map[string]uint32
|
||||||
|
peers []*proto.PeerCompact
|
||||||
|
|
||||||
|
agentVersionOrder map[string]uint32
|
||||||
|
agentVersions []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newComponentEncoder(c *types.NetworkMapComponents) *componentEncoder {
|
||||||
|
return &componentEncoder{
|
||||||
|
components: c,
|
||||||
|
peerOrder: make(map[string]uint32, len(c.Peers)),
|
||||||
|
peers: make([]*proto.PeerCompact, 0, len(c.Peers)),
|
||||||
|
agentVersionOrder: make(map[string]uint32),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) indexAllPeers() {
|
||||||
|
for _, p := range e.components.Peers {
|
||||||
|
if p == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
e.appendPeer(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) appendPeer(p *nbpeer.Peer) uint32 {
|
||||||
|
if idx, ok := e.peerOrder[p.ID]; ok {
|
||||||
|
return idx
|
||||||
|
}
|
||||||
|
idx := uint32(len(e.peers))
|
||||||
|
e.peerOrder[p.ID] = idx
|
||||||
|
e.peers = append(e.peers, toPeerCompact(p, e.agentVersionIndex(p.Meta.WtVersion)))
|
||||||
|
return idx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) agentVersionIndex(v string) uint32 {
|
||||||
|
if idx, ok := e.agentVersionOrder[v]; ok {
|
||||||
|
return idx
|
||||||
|
}
|
||||||
|
// Lazy-initialise the table with "" at index 0 so the empty string
|
||||||
|
// stays interchangeable with proto3's default uint32=0 — peers without
|
||||||
|
// a WtVersion don't force the table to materialise.
|
||||||
|
if v == "" {
|
||||||
|
idx := uint32(len(e.agentVersions))
|
||||||
|
if idx == 0 {
|
||||||
|
e.agentVersions = append(e.agentVersions, "")
|
||||||
|
}
|
||||||
|
e.agentVersionOrder[""] = idx
|
||||||
|
return idx
|
||||||
|
}
|
||||||
|
if len(e.agentVersions) == 0 {
|
||||||
|
e.agentVersions = append(e.agentVersions, "")
|
||||||
|
e.agentVersionOrder[""] = 0
|
||||||
|
}
|
||||||
|
idx := uint32(len(e.agentVersions))
|
||||||
|
e.agentVersionOrder[v] = idx
|
||||||
|
e.agentVersions = append(e.agentVersions, v)
|
||||||
|
return idx
|
||||||
|
}
|
||||||
|
|
||||||
|
// indexRouterPeers ensures every router peer is in the peer dedup table
|
||||||
|
// (c.RouterPeers may contain peers not in c.Peers when validation rules drop
|
||||||
|
// them) and returns their wire indexes for the RouterPeerIndexes field. Must
|
||||||
|
// run before any encoder that resolves peer ids via e.peerOrder.
|
||||||
|
func (e *componentEncoder) indexRouterPeers(routers map[string]*nbpeer.Peer) []uint32 {
|
||||||
|
if len(routers) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]uint32, 0, len(routers))
|
||||||
|
for _, p := range routers {
|
||||||
|
if p == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, e.appendPeer(p))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) encodeGroups() []*proto.GroupCompact {
|
||||||
|
if len(e.components.Groups) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make([]*proto.GroupCompact, 0, len(e.components.Groups))
|
||||||
|
for _, g := range e.components.Groups {
|
||||||
|
if !g.HasSeqID() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
peerIdxs := make([]uint32, 0, len(g.Peers))
|
||||||
|
for _, peerID := range g.Peers {
|
||||||
|
if idx, ok := e.peerOrder[peerID]; ok {
|
||||||
|
peerIdxs = append(peerIdxs, idx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out = append(out, &proto.GroupCompact{
|
||||||
|
Id: g.AccountSeqID,
|
||||||
|
Name: g.Name,
|
||||||
|
PeerIndexes: peerIdxs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodePolicies flattens Policy{Rules} → []PolicyCompact. Returns the wire
|
||||||
|
// list and a map from policy pointer to the indexes of its emitted rules in
|
||||||
|
// that list — used by encodeResourcePoliciesMap to translate
|
||||||
|
// ResourcePoliciesMap[resourceID][]*Policy into wire-side indexes.
|
||||||
|
func (e *componentEncoder) encodePolicies(policies []*types.Policy) ([]*proto.PolicyCompact, map[*types.Policy][]uint32) {
|
||||||
|
if len(policies) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make([]*proto.PolicyCompact, 0, len(policies))
|
||||||
|
idxByPolicy := make(map[*types.Policy][]uint32, len(policies))
|
||||||
|
|
||||||
|
for _, pol := range policies {
|
||||||
|
if !pol.HasSeqID() || !pol.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, r := range pol.Rules {
|
||||||
|
if r == nil || !r.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
idxByPolicy[pol] = append(idxByPolicy[pol], uint32(len(out)))
|
||||||
|
out = append(out, e.encodePolicyRule(pol, r))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, idxByPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodePolicyRule maps a single PolicyRule under pol to a PolicyCompact entry.
|
||||||
|
func (e *componentEncoder) encodePolicyRule(pol *types.Policy, r *types.PolicyRule) *proto.PolicyCompact {
|
||||||
|
return &proto.PolicyCompact{
|
||||||
|
Id: pol.AccountSeqID,
|
||||||
|
Action: networkmap.GetProtoAction(string(r.Action)),
|
||||||
|
Protocol: networkmap.GetProtoProtocol(string(r.Protocol)),
|
||||||
|
Bidirectional: r.Bidirectional,
|
||||||
|
Ports: portsToUint32(r.Ports),
|
||||||
|
PortRanges: portRangesToProto(r.PortRanges),
|
||||||
|
SourceGroupIds: e.groupSeqIDs(r.Sources),
|
||||||
|
DestinationGroupIds: e.groupSeqIDs(r.Destinations),
|
||||||
|
AuthorizedUser: r.AuthorizedUser,
|
||||||
|
AuthorizedGroups: e.encodeAuthorizedGroups(r.AuthorizedGroups),
|
||||||
|
SourceResource: e.resourceToProto(r.SourceResource),
|
||||||
|
DestinationResource: e.resourceToProto(r.DestinationResource),
|
||||||
|
SourcePostureCheckSeqIds: e.postureCheckSeqs(pol.SourcePostureChecks),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// groupSeqIDs maps the xid group IDs in src to their per-account seq ids,
|
||||||
|
// dropping any group that has no seq id assigned.
|
||||||
|
func (e *componentEncoder) groupSeqIDs(src []string) []uint32 {
|
||||||
|
if len(src) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]uint32, 0, len(src))
|
||||||
|
for _, gid := range src {
|
||||||
|
if seq, ok := e.groupSeq(gid); ok {
|
||||||
|
out = append(out, seq)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// unionPolicies merges c.Policies with every policy referenced by
|
||||||
|
// c.ResourcePoliciesMap, deduplicating by pointer identity. Resource-only
|
||||||
|
// policies (relevant to a NetworkResource but not to peer-pair traffic)
|
||||||
|
// only live in ResourcePoliciesMap; without this union step they'd be lost
|
||||||
|
// from the wire and the client's resource-policy lookup would come back
|
||||||
|
// empty.
|
||||||
|
func unionPolicies(policies []*types.Policy, resourcePolicies map[string][]*types.Policy) []*types.Policy {
|
||||||
|
// Fast path: non-router peers have no resource-only policies, so the
|
||||||
|
// "union" is identical to `policies`. Skip the dedup map allocation.
|
||||||
|
if len(resourcePolicies) == 0 {
|
||||||
|
return policies
|
||||||
|
}
|
||||||
|
seen := make(map[*types.Policy]struct{}, len(policies))
|
||||||
|
out := make([]*types.Policy, 0, len(policies))
|
||||||
|
for _, p := range policies {
|
||||||
|
if p == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[p]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[p] = struct{}{}
|
||||||
|
out = append(out, p)
|
||||||
|
}
|
||||||
|
for _, list := range resourcePolicies {
|
||||||
|
for _, p := range list {
|
||||||
|
if p == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[p]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[p] = struct{}{}
|
||||||
|
out = append(out, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeAuthorizedGroups translates rule.AuthorizedGroups (map keyed by
|
||||||
|
// group xid → local-user names) to the wire form (map keyed by group
|
||||||
|
// account_seq_id → UserNameList). Groups without a seq id are dropped —
|
||||||
|
// matches how source/destination group references handle the same case.
|
||||||
|
func (e *componentEncoder) encodeAuthorizedGroups(m map[string][]string) map[uint32]*proto.UserNameList {
|
||||||
|
if len(m) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make(map[uint32]*proto.UserNameList, len(m))
|
||||||
|
for groupID, names := range m {
|
||||||
|
seq, ok := e.groupSeq(groupID)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out[seq] = &proto.UserNameList{Names: append([]string(nil), names...)}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) groupSeq(groupID string) (uint32, bool) {
|
||||||
|
g, ok := e.components.Groups[groupID]
|
||||||
|
if !ok || !g.HasSeqID() {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return g.AccountSeqID, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// resourceToProto translates types.Resource for the wire. For peer-typed
|
||||||
|
// resources the peer id is converted to a peer index into the envelope's
|
||||||
|
// peers array. For other resource types only the type string is shipped
|
||||||
|
// today (Calculate's resource-typed rule path consults SourceResource only
|
||||||
|
// for "peer" — other types fall through to group-based lookup).
|
||||||
|
func (e *componentEncoder) resourceToProto(r types.Resource) *proto.ResourceCompact {
|
||||||
|
if r.ID == "" && r.Type == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := &proto.ResourceCompact{Type: string(r.Type)}
|
||||||
|
if r.Type == types.ResourceTypePeer && r.ID != "" {
|
||||||
|
if idx, ok := e.peerOrder[r.ID]; ok {
|
||||||
|
out.PeerIndexSet = true
|
||||||
|
out.PeerIndex = idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// postureCheckSeqs translates a slice of posture-check xids to their
|
||||||
|
// per-account integer ids using the NetworkMapComponents.PostureCheckXIDToSeq
|
||||||
|
// lookup. Unresolvable xids are silently dropped — matches how group/peer
|
||||||
|
// references handle the same case.
|
||||||
|
func (e *componentEncoder) postureCheckSeqs(xids []string) []uint32 {
|
||||||
|
if len(xids) == 0 || len(e.components.PostureCheckXIDToSeq) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]uint32, 0, len(xids))
|
||||||
|
for _, xid := range xids {
|
||||||
|
if seq, ok := e.components.PostureCheckXIDToSeq[xid]; ok {
|
||||||
|
out = append(out, seq)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// networkSeq translates a Network xid to its per-account integer id using
|
||||||
|
// the NetworkMapComponents.NetworkXIDToSeq lookup. Returns (0,false) when
|
||||||
|
// the xid isn't known — callers decide whether to skip the parent record.
|
||||||
|
func (e *componentEncoder) networkSeq(xid string) (uint32, bool) {
|
||||||
|
if xid == "" {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
seq, ok := e.components.NetworkXIDToSeq[xid]
|
||||||
|
if !ok || seq == 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return seq, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) encodeDNSSettings(s *types.DNSSettings) *proto.DNSSettingsCompact {
|
||||||
|
if s == nil || len(s.DisabledManagementGroups) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := &proto.DNSSettingsCompact{
|
||||||
|
DisabledManagementGroupIds: make([]uint32, 0, len(s.DisabledManagementGroups)),
|
||||||
|
}
|
||||||
|
for _, gid := range s.DisabledManagementGroups {
|
||||||
|
if seq, ok := e.groupSeq(gid); ok {
|
||||||
|
out.DisabledManagementGroupIds = append(out.DisabledManagementGroupIds, seq)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) encodeRoutes(routes []*nbroute.Route) []*proto.RouteRaw {
|
||||||
|
if len(routes) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*proto.RouteRaw, 0, len(routes))
|
||||||
|
for _, r := range routes {
|
||||||
|
if r == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rr := &proto.RouteRaw{
|
||||||
|
Id: r.AccountSeqID,
|
||||||
|
NetId: string(r.NetID),
|
||||||
|
Description: r.Description,
|
||||||
|
KeepRoute: r.KeepRoute,
|
||||||
|
NetworkType: int32(r.NetworkType),
|
||||||
|
Masquerade: r.Masquerade,
|
||||||
|
Metric: int32(r.Metric),
|
||||||
|
Enabled: r.Enabled,
|
||||||
|
SkipAutoApply: r.SkipAutoApply,
|
||||||
|
Domains: r.Domains.ToPunycodeList(),
|
||||||
|
GroupIds: e.groupIDsToSeq(r.Groups),
|
||||||
|
AccessControlGroupIds: e.groupIDsToSeq(r.AccessControlGroups),
|
||||||
|
PeerGroupIds: e.groupIDsToSeq(r.PeerGroups),
|
||||||
|
}
|
||||||
|
if r.Network.IsValid() {
|
||||||
|
rr.NetworkCidr = r.Network.String()
|
||||||
|
}
|
||||||
|
if r.Peer != "" {
|
||||||
|
if idx, ok := e.peerOrder[r.Peer]; ok {
|
||||||
|
rr.PeerIndexSet = true
|
||||||
|
rr.PeerIndex = idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out = append(out, rr)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) groupIDsToSeq(groupIDs []string) []uint32 {
|
||||||
|
if len(groupIDs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]uint32, 0, len(groupIDs))
|
||||||
|
for _, gid := range groupIDs {
|
||||||
|
if seq, ok := e.groupSeq(gid); ok {
|
||||||
|
out = append(out, seq)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) encodeNameServerGroups(nsgs []*nbdns.NameServerGroup) []*proto.NameServerGroupRaw {
|
||||||
|
if len(nsgs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*proto.NameServerGroupRaw, 0, len(nsgs))
|
||||||
|
for _, nsg := range nsgs {
|
||||||
|
if nsg == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
entry := &proto.NameServerGroupRaw{
|
||||||
|
Id: nsg.AccountSeqID,
|
||||||
|
Name: nsg.Name,
|
||||||
|
Description: nsg.Description,
|
||||||
|
Nameservers: encodeNameServers(nsg.NameServers),
|
||||||
|
GroupIds: e.groupIDsToSeq(nsg.Groups),
|
||||||
|
Primary: nsg.Primary,
|
||||||
|
Domains: nsg.Domains,
|
||||||
|
Enabled: nsg.Enabled,
|
||||||
|
SearchDomainsEnabled: nsg.SearchDomainsEnabled,
|
||||||
|
}
|
||||||
|
out = append(out, entry)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeNameServers(servers []nbdns.NameServer) []*proto.NameServer {
|
||||||
|
if len(servers) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*proto.NameServer, 0, len(servers))
|
||||||
|
for _, s := range servers {
|
||||||
|
out = append(out, &proto.NameServer{
|
||||||
|
IP: s.IP.String(),
|
||||||
|
NSType: int64(s.NSType),
|
||||||
|
Port: int64(s.Port),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeSimpleRecords(records []nbdns.SimpleRecord) []*proto.SimpleRecord {
|
||||||
|
if len(records) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*proto.SimpleRecord, 0, len(records))
|
||||||
|
for _, r := range records {
|
||||||
|
out = append(out, &proto.SimpleRecord{
|
||||||
|
Name: r.Name,
|
||||||
|
Type: int64(r.Type),
|
||||||
|
Class: r.Class,
|
||||||
|
TTL: int64(r.TTL),
|
||||||
|
RData: r.RData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeCustomZones(zones []nbdns.CustomZone) []*proto.CustomZone {
|
||||||
|
if len(zones) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*proto.CustomZone, 0, len(zones))
|
||||||
|
for _, z := range zones {
|
||||||
|
out = append(out, &proto.CustomZone{
|
||||||
|
Domain: z.Domain,
|
||||||
|
Records: encodeSimpleRecords(z.Records),
|
||||||
|
SearchDomainDisabled: z.SearchDomainDisabled,
|
||||||
|
NonAuthoritative: z.NonAuthoritative,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) encodeNetworkResources(resources []*resourceTypes.NetworkResource) []*proto.NetworkResourceRaw {
|
||||||
|
if len(resources) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*proto.NetworkResourceRaw, 0, len(resources))
|
||||||
|
for _, r := range resources {
|
||||||
|
if r == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
entry := &proto.NetworkResourceRaw{
|
||||||
|
Id: r.AccountSeqID,
|
||||||
|
Name: r.Name,
|
||||||
|
Description: r.Description,
|
||||||
|
Type: string(r.Type),
|
||||||
|
Address: r.Address,
|
||||||
|
DomainValue: r.Domain,
|
||||||
|
Enabled: r.Enabled,
|
||||||
|
}
|
||||||
|
if seq, ok := e.networkSeq(r.NetworkID); ok {
|
||||||
|
entry.NetworkSeq = seq
|
||||||
|
}
|
||||||
|
if r.Prefix.IsValid() {
|
||||||
|
entry.PrefixCidr = r.Prefix.String()
|
||||||
|
}
|
||||||
|
out = append(out, entry)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) encodeRoutersMap(routersMap map[string]map[string]*routerTypes.NetworkRouter) map[uint32]*proto.NetworkRouterList {
|
||||||
|
if len(routersMap) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make(map[uint32]*proto.NetworkRouterList, len(routersMap))
|
||||||
|
for networkXID, routers := range routersMap {
|
||||||
|
if len(routers) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
netSeq, ok := e.networkSeq(networkXID)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
entries := make([]*proto.NetworkRouterEntry, 0, len(routers))
|
||||||
|
for peerID, r := range routers {
|
||||||
|
if r == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
entry := &proto.NetworkRouterEntry{
|
||||||
|
Id: r.AccountSeqID,
|
||||||
|
PeerGroupIds: e.groupIDsToSeq(r.PeerGroups),
|
||||||
|
Masquerade: r.Masquerade,
|
||||||
|
Metric: int32(r.Metric),
|
||||||
|
Enabled: r.Enabled,
|
||||||
|
}
|
||||||
|
if idx, ok := e.peerOrder[peerID]; ok {
|
||||||
|
entry.PeerIndexSet = true
|
||||||
|
entry.PeerIndex = idx
|
||||||
|
}
|
||||||
|
entries = append(entries, entry)
|
||||||
|
}
|
||||||
|
out[netSeq] = &proto.NetworkRouterList{Entries: entries}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) encodeResourcePoliciesMap(rpm map[string][]*types.Policy, policyToIdxs map[*types.Policy][]uint32) map[uint32]*proto.PolicyIndexes {
|
||||||
|
if len(rpm) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// resourceXIDToSeq is local to one encode — built from components.NetworkResources
|
||||||
|
// (small slice). Network resources without seq id are dropped, matching how
|
||||||
|
// other components-without-seq are silently filtered.
|
||||||
|
resourceXIDToSeq := make(map[string]uint32, len(e.components.NetworkResources))
|
||||||
|
for _, r := range e.components.NetworkResources {
|
||||||
|
if r != nil && r.AccountSeqID != 0 {
|
||||||
|
resourceXIDToSeq[r.ID] = r.AccountSeqID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out := make(map[uint32]*proto.PolicyIndexes, len(rpm))
|
||||||
|
for resourceXID, policies := range rpm {
|
||||||
|
seq, ok := resourceXIDToSeq[resourceXID]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
idxs := make([]uint32, 0, len(policies)*2)
|
||||||
|
for _, pol := range policies {
|
||||||
|
idxs = append(idxs, policyToIdxs[pol]...)
|
||||||
|
}
|
||||||
|
if len(idxs) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out[seq] = &proto.PolicyIndexes{Indexes: idxs}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) encodeGroupIDToUserIDs(m map[string][]string) map[uint32]*proto.UserIDList {
|
||||||
|
if len(m) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make(map[uint32]*proto.UserIDList, len(m))
|
||||||
|
for groupID, userIDs := range m {
|
||||||
|
seq, ok := e.groupSeq(groupID)
|
||||||
|
if !ok || len(userIDs) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out[seq] = &proto.UserIDList{UserIds: userIDs}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringSetToSlice(s map[string]struct{}) []string {
|
||||||
|
if len(s) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]string, 0, len(s))
|
||||||
|
for k := range s {
|
||||||
|
out = append(out, k)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *componentEncoder) encodePostureFailedPeers(m map[string]map[string]struct{}) map[uint32]*proto.PeerIndexSet {
|
||||||
|
if len(m) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make(map[uint32]*proto.PeerIndexSet, len(m))
|
||||||
|
for checkXID, failedPeerIDs := range m {
|
||||||
|
seq, ok := e.components.PostureCheckXIDToSeq[checkXID]
|
||||||
|
if !ok || seq == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
idxs := make([]uint32, 0, len(failedPeerIDs))
|
||||||
|
for peerID := range failedPeerIDs {
|
||||||
|
if idx, ok := e.peerOrder[peerID]; ok {
|
||||||
|
idxs = append(idxs, idx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(idxs) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out[seq] = &proto.PeerIndexSet{PeerIndexes: idxs}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// toAccountSettingsCompact always returns a non-nil message — the client
|
||||||
|
// dereferences it unconditionally during Calculate(), so a nil here would
|
||||||
|
// crash the receiver. A missing types.AccountSettingsInfo on the server
|
||||||
|
// (which shouldn't happen in production but the encoder is exported)
|
||||||
|
// degrades to login_expiration_enabled = false, which makes
|
||||||
|
// LoginExpired() return false for every peer.
|
||||||
|
func toAccountSettingsCompact(s *types.AccountSettingsInfo) *proto.AccountSettingsCompact {
|
||||||
|
if s == nil {
|
||||||
|
return &proto.AccountSettingsCompact{}
|
||||||
|
}
|
||||||
|
return &proto.AccountSettingsCompact{
|
||||||
|
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
|
||||||
|
PeerLoginExpirationNs: int64(s.PeerLoginExpiration),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toAccountNetwork(n *types.Network) *proto.AccountNetwork {
|
||||||
|
if n == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := &proto.AccountNetwork{
|
||||||
|
Identifier: n.Identifier,
|
||||||
|
NetCidr: n.Net.String(),
|
||||||
|
Dns: n.Dns,
|
||||||
|
Serial: n.CurrentSerial(),
|
||||||
|
}
|
||||||
|
if len(n.NetV6.IP) > 0 {
|
||||||
|
out.NetV6Cidr = n.NetV6.String()
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func toPeerCompact(p *nbpeer.Peer, agentVersionIdx uint32) *proto.PeerCompact {
|
||||||
|
pc := &proto.PeerCompact{
|
||||||
|
WgPubKey: decodeWgKey(p.Key),
|
||||||
|
SshPubKey: []byte(p.SSHKey),
|
||||||
|
DnsLabel: p.DNSLabel,
|
||||||
|
AgentVersionIdx: agentVersionIdx,
|
||||||
|
AddedWithSsoLogin: p.UserID != "",
|
||||||
|
LoginExpirationEnabled: p.LoginExpirationEnabled,
|
||||||
|
SshEnabled: p.SSHEnabled,
|
||||||
|
SupportsIpv6: p.SupportsIPv6(),
|
||||||
|
SupportsSourcePrefixes: p.SupportsSourcePrefixes(),
|
||||||
|
ServerSshAllowed: p.Meta.Flags.ServerSSHAllowed,
|
||||||
|
}
|
||||||
|
if p.LastLogin != nil {
|
||||||
|
pc.LastLoginUnixNano = p.LastLogin.UnixNano()
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case !p.IP.IsValid():
|
||||||
|
// leave Ip nil
|
||||||
|
case p.IP.Is4() || p.IP.Is4In6():
|
||||||
|
ip := p.IP.Unmap().As4()
|
||||||
|
pc.Ip = ip[:]
|
||||||
|
default:
|
||||||
|
ip := p.IP.As16()
|
||||||
|
pc.Ip = ip[:]
|
||||||
|
}
|
||||||
|
if p.IPv6.IsValid() {
|
||||||
|
ip := p.IPv6.As16()
|
||||||
|
pc.Ipv6 = ip[:]
|
||||||
|
}
|
||||||
|
return pc
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeWgKey returns the raw 32 bytes of a base64-encoded WireGuard public
|
||||||
|
// key, or nil for an empty / malformed key.
|
||||||
|
func decodeWgKey(s string) []byte {
|
||||||
|
if s == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]byte, wgKeyRawLen)
|
||||||
|
n, err := base64.StdEncoding.Decode(out, []byte(s))
|
||||||
|
if err != nil || n != wgKeyRawLen {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func portsToUint32(ports []string) []uint32 {
|
||||||
|
if len(ports) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]uint32, 0, len(ports))
|
||||||
|
for _, p := range ports {
|
||||||
|
v, err := strconv.ParseUint(p, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, uint32(v))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func portRangesToProto(ranges []types.RulePortRange) []*proto.PortInfo_Range {
|
||||||
|
if len(ranges) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*proto.PortInfo_Range, 0, len(ranges))
|
||||||
|
for _, r := range ranges {
|
||||||
|
out = append(out, &proto.PortInfo_Range{
|
||||||
|
Start: uint32(r.Start),
|
||||||
|
End: uint32(r.End),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
879
management/internals/shared/grpc/components_encoder_test.go
Normal file
879
management/internals/shared/grpc/components_encoder_test.go
Normal file
@@ -0,0 +1,879 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"cmp"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
goproto "google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
nbroute "github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
const testWgKeyA = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
|
||||||
|
const testWgKeyB = "BBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
|
||||||
|
const testWgKeyC = "CBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
|
||||||
|
|
||||||
|
// canonicalize rewrites a NetworkMapComponentsFull in place into a canonical
|
||||||
|
// form: peers reordered by wg_pub_key, with the rest of the message rewritten
|
||||||
|
// to reference the new peer indexes. Groups, policies, and router indexes are
|
||||||
|
// also sorted. After canonicalize, two envelopes built from the same logical
|
||||||
|
// input compare byte-equal via proto.Equal.
|
||||||
|
//
|
||||||
|
// This lives on the test side — the encoder itself emits in map-iteration
|
||||||
|
// order. Test-side normalization is the contract for "two encodes are
|
||||||
|
// equivalent".
|
||||||
|
func canonicalize(full *proto.NetworkMapComponentsFull) {
|
||||||
|
if full == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Canonicalize agent_versions first: sort the slice and rewrite each
|
||||||
|
// peer's AgentVersionIdx accordingly. The empty placeholder stays at
|
||||||
|
// index 0 by convention.
|
||||||
|
avRemap := make(map[uint32]uint32, len(full.AgentVersions))
|
||||||
|
if len(full.AgentVersions) > 0 {
|
||||||
|
// Pair version → original index, sort, rebuild.
|
||||||
|
type avEntry struct {
|
||||||
|
version string
|
||||||
|
oldIdx uint32
|
||||||
|
}
|
||||||
|
entries := make([]avEntry, len(full.AgentVersions))
|
||||||
|
for i, v := range full.AgentVersions {
|
||||||
|
entries[i] = avEntry{version: v, oldIdx: uint32(i)}
|
||||||
|
}
|
||||||
|
// Empty stays at 0; sort the rest by string. Tiebreaker on oldIdx
|
||||||
|
// keeps the canonicalize output stable when two entries compare
|
||||||
|
// equal (the encoder dedups, but defending against future inputs).
|
||||||
|
slices.SortFunc(entries, func(a, b avEntry) int {
|
||||||
|
if a.version == "" && b.version != "" {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if b.version == "" && a.version != "" {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if c := cmp.Compare(a.version, b.version); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
return cmp.Compare(a.oldIdx, b.oldIdx)
|
||||||
|
})
|
||||||
|
newVersions := make([]string, len(entries))
|
||||||
|
for newIdx, e := range entries {
|
||||||
|
avRemap[e.oldIdx] = uint32(newIdx)
|
||||||
|
newVersions[newIdx] = e.version
|
||||||
|
}
|
||||||
|
full.AgentVersions = newVersions
|
||||||
|
}
|
||||||
|
for _, p := range full.Peers {
|
||||||
|
if newIdx, ok := avRemap[p.AgentVersionIdx]; ok {
|
||||||
|
p.AgentVersionIdx = newIdx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type peerEntry struct {
|
||||||
|
peer *proto.PeerCompact
|
||||||
|
oldIdx uint32
|
||||||
|
}
|
||||||
|
entries := make([]peerEntry, len(full.Peers))
|
||||||
|
for i, p := range full.Peers {
|
||||||
|
entries[i] = peerEntry{peer: p, oldIdx: uint32(i)}
|
||||||
|
}
|
||||||
|
// DnsLabel is unique per peer; it tiebreaks on equal WgPubKey (e.g. both
|
||||||
|
// nil from malformed keys, or both empty for placeholders).
|
||||||
|
slices.SortFunc(entries, func(a, b peerEntry) int {
|
||||||
|
if c := bytes.Compare(a.peer.WgPubKey, b.peer.WgPubKey); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
return cmp.Compare(a.peer.DnsLabel, b.peer.DnsLabel)
|
||||||
|
})
|
||||||
|
|
||||||
|
remap := make(map[uint32]uint32, len(entries))
|
||||||
|
newPeers := make([]*proto.PeerCompact, len(entries))
|
||||||
|
for newIdx, e := range entries {
|
||||||
|
remap[e.oldIdx] = uint32(newIdx)
|
||||||
|
newPeers[newIdx] = e.peer
|
||||||
|
}
|
||||||
|
full.Peers = newPeers
|
||||||
|
|
||||||
|
full.RouterPeerIndexes = remapAndSort(full.RouterPeerIndexes, remap)
|
||||||
|
for _, g := range full.Groups {
|
||||||
|
g.PeerIndexes = remapAndSort(g.PeerIndexes, remap)
|
||||||
|
}
|
||||||
|
slices.SortFunc(full.Groups, func(a, b *proto.GroupCompact) int { return cmp.Compare(a.Id, b.Id) })
|
||||||
|
|
||||||
|
for _, r := range full.Routes {
|
||||||
|
if r.PeerIndexSet {
|
||||||
|
if newIdx, ok := remap[r.PeerIndex]; ok {
|
||||||
|
r.PeerIndex = newIdx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
slices.Sort(r.GroupIds)
|
||||||
|
slices.Sort(r.AccessControlGroupIds)
|
||||||
|
slices.Sort(r.PeerGroupIds)
|
||||||
|
}
|
||||||
|
slices.SortFunc(full.Routes, func(a, b *proto.RouteRaw) int { return cmp.Compare(a.Id, b.Id) })
|
||||||
|
|
||||||
|
for _, list := range full.RoutersMap {
|
||||||
|
for _, entry := range list.Entries {
|
||||||
|
if entry.PeerIndexSet {
|
||||||
|
if newIdx, ok := remap[entry.PeerIndex]; ok {
|
||||||
|
entry.PeerIndex = newIdx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
slices.Sort(entry.PeerGroupIds)
|
||||||
|
}
|
||||||
|
slices.SortFunc(list.Entries, func(a, b *proto.NetworkRouterEntry) int { return cmp.Compare(a.Id, b.Id) })
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, set := range full.PostureFailedPeers {
|
||||||
|
set.PeerIndexes = remapAndSort(set.PeerIndexes, remap)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range full.Policies {
|
||||||
|
slices.Sort(p.SourceGroupIds)
|
||||||
|
slices.Sort(p.DestinationGroupIds)
|
||||||
|
}
|
||||||
|
// Sort policies by (Id, source_group_ids, destination_group_ids) so that
|
||||||
|
// multiple PolicyCompact entries sharing the same Id (one per rule, when
|
||||||
|
// a Policy has multiple rules) still get a deterministic order. After
|
||||||
|
// sorting we remap indexes in ResourcePoliciesMap.
|
||||||
|
policyOldOrder := make(map[*proto.PolicyCompact]uint32, len(full.Policies))
|
||||||
|
for i, p := range full.Policies {
|
||||||
|
policyOldOrder[p] = uint32(i)
|
||||||
|
}
|
||||||
|
slices.SortFunc(full.Policies, func(a, b *proto.PolicyCompact) int {
|
||||||
|
if c := cmp.Compare(a.Id, b.Id); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
if c := slices.Compare(a.SourceGroupIds, b.SourceGroupIds); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
return slices.Compare(a.DestinationGroupIds, b.DestinationGroupIds)
|
||||||
|
})
|
||||||
|
policyRemap := make(map[uint32]uint32, len(full.Policies))
|
||||||
|
for newIdx, p := range full.Policies {
|
||||||
|
policyRemap[policyOldOrder[p]] = uint32(newIdx)
|
||||||
|
}
|
||||||
|
for _, idxs := range full.ResourcePoliciesMap {
|
||||||
|
idxs.Indexes = remapAndSort(idxs.Indexes, policyRemap)
|
||||||
|
}
|
||||||
|
for _, list := range full.GroupIdToUserIds {
|
||||||
|
slices.Sort(list.UserIds)
|
||||||
|
}
|
||||||
|
slices.Sort(full.AllowedUserIds)
|
||||||
|
}
|
||||||
|
|
||||||
|
func remapAndSort(idxs []uint32, remap map[uint32]uint32) []uint32 {
|
||||||
|
out := make([]uint32, 0, len(idxs))
|
||||||
|
for _, i := range idxs {
|
||||||
|
if newIdx, ok := remap[i]; ok {
|
||||||
|
out = append(out, newIdx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
slices.Sort(out)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// envelopesEquivalent decodes both envelopes, canonicalizes them, and reports
|
||||||
|
// whether they're proto.Equal. Use instead of byte-comparing marshaled output:
|
||||||
|
// the encoder is intentionally non-deterministic.
|
||||||
|
func envelopesEquivalent(a, b *proto.NetworkMapEnvelope) bool {
|
||||||
|
canonicalize(a.GetFull())
|
||||||
|
canonicalize(b.GetFull())
|
||||||
|
return goproto.Equal(a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestComponents() *types.NetworkMapComponents {
|
||||||
|
peerA := &nbpeer.Peer{
|
||||||
|
ID: "peer-a",
|
||||||
|
Key: testWgKeyA,
|
||||||
|
IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}),
|
||||||
|
DNSLabel: "peera",
|
||||||
|
SSHKey: "ssh-a",
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
||||||
|
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||||
|
}
|
||||||
|
peerB := &nbpeer.Peer{
|
||||||
|
ID: "peer-b",
|
||||||
|
Key: testWgKeyB,
|
||||||
|
IP: netip.AddrFrom4([4]byte{100, 64, 0, 2}),
|
||||||
|
IPv6: netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}),
|
||||||
|
DNSLabel: "peerb",
|
||||||
|
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.25.0"},
|
||||||
|
}
|
||||||
|
peerC := &nbpeer.Peer{
|
||||||
|
ID: "peer-c",
|
||||||
|
Key: testWgKeyC,
|
||||||
|
IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}),
|
||||||
|
DNSLabel: "peerc",
|
||||||
|
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||||
|
}
|
||||||
|
|
||||||
|
return &types.NetworkMapComponents{
|
||||||
|
PeerID: "peer-a",
|
||||||
|
Network: &types.Network{
|
||||||
|
Identifier: "net-test",
|
||||||
|
Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)},
|
||||||
|
Serial: 7,
|
||||||
|
},
|
||||||
|
AccountSettings: &types.AccountSettingsInfo{
|
||||||
|
PeerLoginExpirationEnabled: true,
|
||||||
|
PeerLoginExpiration: 2 * time.Hour,
|
||||||
|
},
|
||||||
|
Peers: map[string]*nbpeer.Peer{
|
||||||
|
"peer-a": peerA,
|
||||||
|
"peer-b": peerB,
|
||||||
|
"peer-c": peerC,
|
||||||
|
},
|
||||||
|
Groups: map[string]*types.Group{
|
||||||
|
"group-src": {ID: "group-src", AccountSeqID: 1, Name: "Src", Peers: []string{"peer-a"}},
|
||||||
|
"group-dst": {ID: "group-dst", AccountSeqID: 2, Name: "Dst", Peers: []string{"peer-b", "peer-c"}},
|
||||||
|
},
|
||||||
|
Policies: []*types.Policy{
|
||||||
|
{
|
||||||
|
ID: "pol-1",
|
||||||
|
AccountSeqID: 10,
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*types.PolicyRule{{
|
||||||
|
ID: "rule-1", Enabled: true, Action: types.PolicyTrafficActionAccept,
|
||||||
|
Protocol: types.PolicyRuleProtocolTCP, Bidirectional: true,
|
||||||
|
Ports: []string{"22", "80"},
|
||||||
|
PortRanges: []types.RulePortRange{{Start: 8000, End: 8100}},
|
||||||
|
Sources: []string{"group-src"},
|
||||||
|
Destinations: []string{"group-dst"},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
RouterPeers: map[string]*nbpeer.Peer{"peer-c": peerC},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_Basic(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||||
|
Components: c,
|
||||||
|
DNSDomain: "netbird.cloud",
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NotNil(t, env)
|
||||||
|
full := env.GetFull()
|
||||||
|
require.NotNil(t, full, "envelope must contain Full payload")
|
||||||
|
|
||||||
|
assert.EqualValues(t, 7, full.Serial)
|
||||||
|
assert.Equal(t, "netbird.cloud", full.DnsDomain)
|
||||||
|
|
||||||
|
require.NotNil(t, full.Network)
|
||||||
|
assert.Equal(t, "net-test", full.Network.Identifier)
|
||||||
|
assert.Equal(t, "100.64.0.0/10", full.Network.NetCidr)
|
||||||
|
|
||||||
|
require.NotNil(t, full.AccountSettings)
|
||||||
|
assert.True(t, full.AccountSettings.PeerLoginExpirationEnabled)
|
||||||
|
assert.EqualValues(t, (2 * time.Hour).Nanoseconds(), full.AccountSettings.PeerLoginExpirationNs)
|
||||||
|
|
||||||
|
require.Len(t, full.Peers, 3)
|
||||||
|
byLabel := map[string]*proto.PeerCompact{}
|
||||||
|
for _, p := range full.Peers {
|
||||||
|
assert.Len(t, p.WgPubKey, 32, "wg key must be raw 32 bytes")
|
||||||
|
assert.Len(t, p.Ip, 4, "ipv4 must be raw 4 bytes")
|
||||||
|
byLabel[p.DnsLabel] = p
|
||||||
|
}
|
||||||
|
assert.Len(t, byLabel["peerb"].Ipv6, 16, "peer-b has ipv6 → 16 bytes")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_RepeatEncodesEquivalent(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
|
||||||
|
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||||
|
|
||||||
|
// Hammer it 100 times — Go map iteration is randomized per call, so each
|
||||||
|
// run produces different wire bytes, but the canonicalized form must
|
||||||
|
// match.
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
got := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||||
|
require.True(t, envelopesEquivalent(expected, got),
|
||||||
|
"encode #%d must be semantically equivalent to first encode", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_ConcurrentEncodesEquivalent(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
|
||||||
|
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||||
|
|
||||||
|
const goroutines = 50
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(goroutines)
|
||||||
|
results := make([]*proto.NetworkMapEnvelope, goroutines)
|
||||||
|
for i := 0; i < goroutines; i++ {
|
||||||
|
i := i
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
results[i] = EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
for i, got := range results {
|
||||||
|
require.NotNil(t, got, "goroutine %d returned nil", i)
|
||||||
|
require.True(t, envelopesEquivalent(expected, got),
|
||||||
|
"goroutine %d produced inequivalent envelope", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_GroupsByAccountSeqID(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Len(t, full.Groups, 2)
|
||||||
|
|
||||||
|
groupByID := map[uint32]*proto.GroupCompact{}
|
||||||
|
for _, g := range full.Groups {
|
||||||
|
groupByID[g.Id] = g
|
||||||
|
}
|
||||||
|
require.Contains(t, groupByID, uint32(1))
|
||||||
|
require.Contains(t, groupByID, uint32(2))
|
||||||
|
assert.Equal(t, "Src", groupByID[1].Name)
|
||||||
|
assert.Equal(t, "Dst", groupByID[2].Name)
|
||||||
|
assert.Len(t, groupByID[1].PeerIndexes, 1)
|
||||||
|
assert.Len(t, groupByID[2].PeerIndexes, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_PolicyExpansion(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Len(t, full.Policies, 1)
|
||||||
|
pc := full.Policies[0]
|
||||||
|
assert.EqualValues(t, 10, pc.Id)
|
||||||
|
assert.Equal(t, proto.RuleAction_ACCEPT, pc.Action)
|
||||||
|
assert.Equal(t, proto.RuleProtocol_TCP, pc.Protocol)
|
||||||
|
assert.True(t, pc.Bidirectional)
|
||||||
|
assert.Equal(t, []uint32{22, 80}, pc.Ports)
|
||||||
|
require.Len(t, pc.PortRanges, 1)
|
||||||
|
assert.EqualValues(t, 8000, pc.PortRanges[0].Start)
|
||||||
|
assert.EqualValues(t, 8100, pc.PortRanges[0].End)
|
||||||
|
assert.Equal(t, []uint32{1}, pc.SourceGroupIds)
|
||||||
|
assert.Equal(t, []uint32{2}, pc.DestinationGroupIds)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_RouterIndexes(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Len(t, full.RouterPeerIndexes, 1)
|
||||||
|
idx := full.RouterPeerIndexes[0]
|
||||||
|
require.Less(t, int(idx), len(full.Peers))
|
||||||
|
assert.Equal(t, "peerc", full.Peers[idx].DnsLabel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_AgentVersionDedup(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Len(t, full.AgentVersions, 3, "empty placeholder + 2 distinct versions")
|
||||||
|
assert.Equal(t, "", full.AgentVersions[0], "index 0 reserved for empty version")
|
||||||
|
assert.ElementsMatch(t, []string{"0.40.0", "0.25.0"}, full.AgentVersions[1:],
|
||||||
|
"two distinct versions, order depends on map iteration")
|
||||||
|
|
||||||
|
idxByLabel := map[string]uint32{}
|
||||||
|
for _, p := range full.Peers {
|
||||||
|
idxByLabel[p.DnsLabel] = p.AgentVersionIdx
|
||||||
|
}
|
||||||
|
assert.Equal(t, idxByLabel["peera"], idxByLabel["peerc"], "peers with the same agent version share an index")
|
||||||
|
assert.NotEqual(t, idxByLabel["peera"], idxByLabel["peerb"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_DisabledPolicySkipped(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
c.Policies[0].Enabled = false
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
assert.Empty(t, full.Policies)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_GroupZeroSeqIDSkipped(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
c.Groups["group-src"].AccountSeqID = 0
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Len(t, full.Groups, 1, "groups with AccountSeqID=0 are not yet persisted and must be skipped")
|
||||||
|
assert.EqualValues(t, 2, full.Groups[0].Id)
|
||||||
|
|
||||||
|
require.Len(t, full.Policies, 1)
|
||||||
|
pc := full.Policies[0]
|
||||||
|
assert.Empty(t, pc.SourceGroupIds, "rule references a group that was filtered out → no group id on wire")
|
||||||
|
assert.Equal(t, []uint32{2}, pc.DestinationGroupIds)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_TwoPeersSameMalformedKey(t *testing.T) {
|
||||||
|
// Both peers have nil WgPubKey after decode; canonicalize must still
|
||||||
|
// produce a stable order using DnsLabel as a tiebreaker, so 100 encodes
|
||||||
|
// canonicalize identically.
|
||||||
|
c := newTestComponents()
|
||||||
|
c.Peers["peer-a"].Key = "garbage-a-!!!"
|
||||||
|
c.Peers["peer-b"].Key = "garbage-b-!!!"
|
||||||
|
|
||||||
|
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
got := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||||
|
require.True(t, envelopesEquivalent(expected, got),
|
||||||
|
"encode #%d with two same-key peers must canonicalize equivalently", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_MalformedWgKey(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
c.Peers["peer-a"].Key = "not-base64-!!!"
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Len(t, full.Peers, 3)
|
||||||
|
|
||||||
|
var byLabel = map[string]*proto.PeerCompact{}
|
||||||
|
for _, p := range full.Peers {
|
||||||
|
byLabel[p.DnsLabel] = p
|
||||||
|
}
|
||||||
|
assert.Nil(t, byLabel["peera"].WgPubKey, "peer with malformed key encodes nil WgPubKey")
|
||||||
|
assert.Len(t, byLabel["peerb"].WgPubKey, 32, "other peers retain their key")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_IPv6OnlyPeer(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
v6Only := &nbpeer.Peer{
|
||||||
|
ID: "peer-v6",
|
||||||
|
Key: testWgKeyA,
|
||||||
|
IPv6: netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9}),
|
||||||
|
DNSLabel: "peerv6",
|
||||||
|
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||||
|
}
|
||||||
|
c.Peers["peer-v6"] = v6Only
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
var found *proto.PeerCompact
|
||||||
|
for _, p := range full.Peers {
|
||||||
|
if p.DnsLabel == "peerv6" {
|
||||||
|
found = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotNil(t, found, "ipv6-only peer must be present")
|
||||||
|
assert.Empty(t, found.Ip, "no IPv4 address → empty Ip")
|
||||||
|
assert.Len(t, found.Ipv6, 16)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_PeerWithoutIP(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
c.Peers["peer-noip"] = &nbpeer.Peer{
|
||||||
|
ID: "peer-noip",
|
||||||
|
Key: testWgKeyA,
|
||||||
|
DNSLabel: "peernoip",
|
||||||
|
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||||
|
}
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
var found *proto.PeerCompact
|
||||||
|
for _, p := range full.Peers {
|
||||||
|
if p.DnsLabel == "peernoip" {
|
||||||
|
found = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotNil(t, found)
|
||||||
|
assert.Empty(t, found.Ip)
|
||||||
|
assert.Empty(t, found.Ipv6)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_EmptyInput(t *testing.T) {
|
||||||
|
c := &types.NetworkMapComponents{
|
||||||
|
Network: &types.Network{Identifier: "x", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}},
|
||||||
|
}
|
||||||
|
|
||||||
|
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||||
|
|
||||||
|
full := env.GetFull()
|
||||||
|
require.NotNil(t, full)
|
||||||
|
assert.Empty(t, full.Peers)
|
||||||
|
assert.Empty(t, full.Groups)
|
||||||
|
assert.Empty(t, full.Policies)
|
||||||
|
assert.Empty(t, full.RouterPeerIndexes)
|
||||||
|
require.NotNil(t, full.AccountSettings, "AccountSettingsCompact must always be emitted (client dereferences it unconditionally)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_PeerLoginExpirationFields(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
now := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC)
|
||||||
|
c.Peers["peer-a"].UserID = "user-1"
|
||||||
|
c.Peers["peer-a"].LoginExpirationEnabled = true
|
||||||
|
c.Peers["peer-a"].LastLogin = &now
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
var pa *proto.PeerCompact
|
||||||
|
for _, p := range full.Peers {
|
||||||
|
if p.DnsLabel == "peera" {
|
||||||
|
pa = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotNil(t, pa)
|
||||||
|
assert.True(t, pa.AddedWithSsoLogin)
|
||||||
|
assert.True(t, pa.LoginExpirationEnabled)
|
||||||
|
assert.Equal(t, now.UnixNano(), pa.LastLoginUnixNano)
|
||||||
|
|
||||||
|
// peer-b has no UserID and no LastLogin → all fields zero-value.
|
||||||
|
var pb *proto.PeerCompact
|
||||||
|
for _, p := range full.Peers {
|
||||||
|
if p.DnsLabel == "peerb" {
|
||||||
|
pb = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotNil(t, pb)
|
||||||
|
assert.False(t, pb.AddedWithSsoLogin)
|
||||||
|
assert.False(t, pb.LoginExpirationEnabled)
|
||||||
|
assert.Zero(t, pb.LastLoginUnixNano)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_RoutesRoundTrip(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
c.Routes = []*nbroute.Route{
|
||||||
|
{
|
||||||
|
ID: "route-peer",
|
||||||
|
AccountSeqID: 100,
|
||||||
|
NetID: "net-A",
|
||||||
|
Description: "via peer-c",
|
||||||
|
Network: netip.MustParsePrefix("10.0.0.0/16"),
|
||||||
|
Peer: "peer-c", // peer ID, not WG key
|
||||||
|
Groups: []string{"group-src"},
|
||||||
|
AccessControlGroups: []string{"group-dst"},
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "route-peergroup",
|
||||||
|
AccountSeqID: 101,
|
||||||
|
NetID: "net-B",
|
||||||
|
Network: netip.MustParsePrefix("10.1.0.0/16"),
|
||||||
|
PeerGroups: []string{"group-src", "group-dst"},
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "route-no-seq",
|
||||||
|
AccountSeqID: 0, // unset — should still ship (no group seq filter on routes)
|
||||||
|
Network: netip.MustParsePrefix("10.2.0.0/16"),
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Len(t, full.Routes, 3)
|
||||||
|
byNetID := map[string]*proto.RouteRaw{}
|
||||||
|
for _, r := range full.Routes {
|
||||||
|
byNetID[r.NetId] = r
|
||||||
|
}
|
||||||
|
|
||||||
|
r1 := byNetID["net-A"]
|
||||||
|
require.NotNil(t, r1)
|
||||||
|
assert.True(t, r1.PeerIndexSet, "route with peer must set peer_index_set")
|
||||||
|
require.Less(t, int(r1.PeerIndex), len(full.Peers))
|
||||||
|
assert.Equal(t, "peerc", full.Peers[r1.PeerIndex].DnsLabel)
|
||||||
|
assert.Equal(t, []uint32{1}, r1.GroupIds, "group-src has AccountSeqID 1")
|
||||||
|
assert.Equal(t, []uint32{2}, r1.AccessControlGroupIds, "group-dst has AccountSeqID 2")
|
||||||
|
assert.Empty(t, r1.PeerGroupIds)
|
||||||
|
|
||||||
|
r2 := byNetID["net-B"]
|
||||||
|
require.NotNil(t, r2)
|
||||||
|
assert.False(t, r2.PeerIndexSet, "route with peer_groups must NOT set peer_index_set")
|
||||||
|
assert.ElementsMatch(t, []uint32{1, 2}, r2.PeerGroupIds)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_RouteWithMissingPeerLeavesIndexUnset(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
c.Routes = []*nbroute.Route{{
|
||||||
|
ID: "route-x",
|
||||||
|
AccountSeqID: 100,
|
||||||
|
Peer: "peer-not-in-components",
|
||||||
|
Network: netip.MustParsePrefix("10.0.0.0/16"),
|
||||||
|
Enabled: true,
|
||||||
|
}}
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Len(t, full.Routes, 1)
|
||||||
|
assert.False(t, full.Routes[0].PeerIndexSet,
|
||||||
|
"missing peer reference must not pretend to point at peer index 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_ResourceOnlyPolicyShippedAndIndexed(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
// Policy that exists ONLY in ResourcePoliciesMap, not in c.Policies. This
|
||||||
|
// is the I1 case — without unionPolicies the encoder would silently
|
||||||
|
// drop it from the wire.
|
||||||
|
resourceOnlyPolicy := &types.Policy{
|
||||||
|
ID: "pol-resource", AccountSeqID: 99, Enabled: true,
|
||||||
|
Rules: []*types.PolicyRule{{
|
||||||
|
ID: "rule-r", Enabled: true, Action: types.PolicyTrafficActionAccept,
|
||||||
|
Protocol: types.PolicyRuleProtocolTCP,
|
||||||
|
Sources: []string{"group-src"},
|
||||||
|
Destinations: []string{"group-dst"},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
c.ResourcePoliciesMap = map[string][]*types.Policy{
|
||||||
|
"resource-x": {c.Policies[0], resourceOnlyPolicy}, // shared + resource-only
|
||||||
|
}
|
||||||
|
// Resource must appear in components.NetworkResources with a seq id —
|
||||||
|
// encoder uses that to translate the xid map key to uint32.
|
||||||
|
c.NetworkResources = []*resourceTypes.NetworkResource{
|
||||||
|
{ID: "resource-x", AccountSeqID: 77, Name: "res-x", Enabled: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Len(t, full.Policies, 2, "encoded policies must include both peer-traffic and resource-only")
|
||||||
|
|
||||||
|
policyByID := map[uint32]*proto.PolicyCompact{}
|
||||||
|
policyIdxByID := map[uint32]uint32{}
|
||||||
|
for i, p := range full.Policies {
|
||||||
|
policyByID[p.Id] = p
|
||||||
|
policyIdxByID[p.Id] = uint32(i)
|
||||||
|
}
|
||||||
|
require.Contains(t, policyByID, uint32(10), "original peer-traffic policy id 10")
|
||||||
|
require.Contains(t, policyByID, uint32(99), "resource-only policy id 99")
|
||||||
|
|
||||||
|
require.Contains(t, full.ResourcePoliciesMap, uint32(77))
|
||||||
|
idxs := full.ResourcePoliciesMap[77].Indexes
|
||||||
|
require.Len(t, idxs, 2)
|
||||||
|
assert.ElementsMatch(t, []uint32{policyIdxByID[10], policyIdxByID[99]}, idxs,
|
||||||
|
"resource policies map must reference both wire policy indexes")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_NameServerGroups(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
c.NameServerGroups = []*nbdns.NameServerGroup{{
|
||||||
|
ID: "nsg-1", AccountSeqID: 50, Name: "Main", Description: "primary",
|
||||||
|
NameServers: []nbdns.NameServer{{
|
||||||
|
IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53,
|
||||||
|
}},
|
||||||
|
Groups: []string{"group-src", "group-not-persisted"},
|
||||||
|
Primary: true, Enabled: true,
|
||||||
|
Domains: []string{"corp.example"},
|
||||||
|
}}
|
||||||
|
c.Groups["group-not-persisted"] = &types.Group{ID: "group-not-persisted", AccountSeqID: 0, Peers: []string{}}
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Len(t, full.NameserverGroups, 1)
|
||||||
|
nsg := full.NameserverGroups[0]
|
||||||
|
assert.EqualValues(t, 50, nsg.Id)
|
||||||
|
assert.Equal(t, "Main", nsg.Name)
|
||||||
|
assert.True(t, nsg.Primary)
|
||||||
|
require.Len(t, nsg.Nameservers, 1)
|
||||||
|
assert.Equal(t, "8.8.8.8", nsg.Nameservers[0].IP)
|
||||||
|
assert.Equal(t, []uint32{1}, nsg.GroupIds, "group-not-persisted is filtered out (AccountSeqID=0)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_PostureFailedPeers(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
c.PostureCheckXIDToSeq = map[string]uint32{"check-1": 33}
|
||||||
|
c.PostureFailedPeers = map[string]map[string]struct{}{
|
||||||
|
"check-1": {
|
||||||
|
"peer-a": {},
|
||||||
|
"peer-b": {},
|
||||||
|
"peer-not-in-account": {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Contains(t, full.PostureFailedPeers, uint32(33))
|
||||||
|
idxs := full.PostureFailedPeers[33].PeerIndexes
|
||||||
|
assert.Len(t, idxs, 2, "missing peer is silently dropped (filterPostureFailedPeers guarantees presence in real data)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_RoutersMap(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
c.NetworkXIDToSeq = map[string]uint32{"net-1": 5}
|
||||||
|
c.RoutersMap = map[string]map[string]*routerTypes.NetworkRouter{
|
||||||
|
"net-1": {
|
||||||
|
"peer-c": {
|
||||||
|
ID: "router-1", AccountSeqID: 200,
|
||||||
|
Peer: "peer-c", Masquerade: true, Metric: 10, Enabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Contains(t, full.RoutersMap, uint32(5))
|
||||||
|
entries := full.RoutersMap[5].Entries
|
||||||
|
require.Len(t, entries, 1)
|
||||||
|
e := entries[0]
|
||||||
|
assert.EqualValues(t, 200, e.Id)
|
||||||
|
assert.True(t, e.PeerIndexSet)
|
||||||
|
require.Less(t, int(e.PeerIndex), len(full.Peers))
|
||||||
|
assert.Equal(t, "peerc", full.Peers[e.PeerIndex].DnsLabel)
|
||||||
|
assert.True(t, e.Masquerade)
|
||||||
|
assert.EqualValues(t, 10, e.Metric)
|
||||||
|
assert.True(t, e.Enabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_RouterPeerNotInComponentsPeers(t *testing.T) {
|
||||||
|
// Router peer in c.RouterPeers but NOT in c.Peers (validation may have
|
||||||
|
// filtered it). indexRouterPeers runs before encodeRoutersMap, so the
|
||||||
|
// peer_index reference must still resolve.
|
||||||
|
c := newTestComponents()
|
||||||
|
delete(c.Peers, "peer-c")
|
||||||
|
routerPeer := &nbpeer.Peer{
|
||||||
|
ID: "peer-c", Key: testWgKeyC, IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}),
|
||||||
|
DNSLabel: "peerc", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||||
|
}
|
||||||
|
c.RouterPeers = map[string]*nbpeer.Peer{"peer-c": routerPeer}
|
||||||
|
c.NetworkXIDToSeq = map[string]uint32{"net-1": 5}
|
||||||
|
c.RoutersMap = map[string]map[string]*routerTypes.NetworkRouter{
|
||||||
|
"net-1": {"peer-c": {ID: "r-1", AccountSeqID: 1, Peer: "peer-c", Enabled: true}},
|
||||||
|
}
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Contains(t, full.RoutersMap, uint32(5))
|
||||||
|
require.Len(t, full.RoutersMap[5].Entries, 1)
|
||||||
|
e := full.RoutersMap[5].Entries[0]
|
||||||
|
assert.True(t, e.PeerIndexSet, "router peer must be indexed even when not in c.Peers")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_DNSSettingsFiltersUnpersistedGroups(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
c.DNSSettings = &types.DNSSettings{
|
||||||
|
DisabledManagementGroups: []string{"group-src", "group-missing", "group-no-seq"},
|
||||||
|
}
|
||||||
|
c.Groups["group-no-seq"] = &types.Group{ID: "group-no-seq", AccountSeqID: 0}
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.NotNil(t, full.DnsSettings)
|
||||||
|
assert.Equal(t, []uint32{1}, full.DnsSettings.DisabledManagementGroupIds,
|
||||||
|
"only group-src (AccountSeqID=1) survives — missing and unpersisted are dropped")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_GroupIDToUserIDs(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
c.GroupIDToUserIDs = map[string][]string{
|
||||||
|
"group-src": {"user-1", "user-2"},
|
||||||
|
"group-no-seq": {"user-3"}, // group not persisted → drop
|
||||||
|
"group-missing": {"user-4"}, // group not in components → drop
|
||||||
|
}
|
||||||
|
c.Groups["group-no-seq"] = &types.Group{ID: "group-no-seq", AccountSeqID: 0}
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.Len(t, full.GroupIdToUserIds, 1, "only persisted+present groups survive")
|
||||||
|
require.Contains(t, full.GroupIdToUserIds, uint32(1))
|
||||||
|
assert.ElementsMatch(t, []string{"user-1", "user-2"}, full.GroupIdToUserIds[1].UserIds)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProxyPatch_EmptyInputReturnsNil(t *testing.T) {
|
||||||
|
assert.Nil(t, toProxyPatch(nil, "netbird.cloud", false, false))
|
||||||
|
assert.Nil(t, toProxyPatch(&types.NetworkMap{}, "netbird.cloud", false, false),
|
||||||
|
"empty NetworkMap (no peers, rules, routes etc) → nil patch so proto3 omits the field")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProxyPatch_PopulatesAllFields(t *testing.T) {
|
||||||
|
nm := &types.NetworkMap{
|
||||||
|
Peers: []*nbpeer.Peer{{
|
||||||
|
ID: "ext-peer", Key: testWgKeyA, IP: netip.AddrFrom4([4]byte{100, 64, 0, 9}),
|
||||||
|
DNSLabel: "extpeer", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||||
|
}},
|
||||||
|
FirewallRules: []*types.FirewallRule{{
|
||||||
|
PeerIP: "100.64.0.9", Action: "accept", Direction: 0, Protocol: "tcp",
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
patch := toProxyPatch(nm, "netbird.cloud", false, false)
|
||||||
|
|
||||||
|
require.NotNil(t, patch)
|
||||||
|
assert.Len(t, patch.Peers, 1)
|
||||||
|
assert.Len(t, patch.FirewallRules, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEncodeNetworkMapEnvelope_ProxyPatchPropagated covers the ProxyPatch
|
||||||
|
// pass-through in both encoder branches (normal path + nil-Components
|
||||||
|
// graceful-degrade). Guards against a regression that drops `ProxyPatch:`
|
||||||
|
// from one of the envelope struct literals.
|
||||||
|
func TestEncodeNetworkMapEnvelope_ProxyPatchPropagated(t *testing.T) {
|
||||||
|
patch := &proto.ProxyPatch{
|
||||||
|
ForwardingRules: []*proto.ForwardingRule{{
|
||||||
|
Protocol: proto.RuleProtocol_TCP,
|
||||||
|
DestinationPort: &proto.PortInfo{PortSelection: &proto.PortInfo_Port{Port: 80}},
|
||||||
|
TranslatedAddress: net.IPv4(10, 0, 0, 1).To4(),
|
||||||
|
TranslatedPort: &proto.PortInfo{PortSelection: &proto.PortInfo_Port{Port: 8080}},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("normal_path", func(t *testing.T) {
|
||||||
|
c := newTestComponents()
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||||
|
Components: c,
|
||||||
|
ProxyPatch: patch,
|
||||||
|
}).GetFull()
|
||||||
|
|
||||||
|
require.NotNil(t, full.ProxyPatch, "ProxyPatch must propagate through the normal encode path")
|
||||||
|
assert.Len(t, full.ProxyPatch.ForwardingRules, 1)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil_components_graceful_degrade", func(t *testing.T) {
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||||
|
Components: nil,
|
||||||
|
ProxyPatch: patch,
|
||||||
|
}).GetFull()
|
||||||
|
|
||||||
|
require.NotNil(t, full.ProxyPatch, "ProxyPatch must propagate through the nil-Components branch too")
|
||||||
|
assert.Len(t, full.ProxyPatch.ForwardingRules, 1)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_NilComponentsGracefulDegrade(t *testing.T) {
|
||||||
|
// nil Components → minimal envelope, no crash. Matches the legacy
|
||||||
|
// behaviour for missing/unvalidated peers.
|
||||||
|
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||||
|
Components: nil,
|
||||||
|
DNSDomain: "netbird.cloud",
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NotNil(t, env)
|
||||||
|
full := env.GetFull()
|
||||||
|
require.NotNil(t, full)
|
||||||
|
require.NotNil(t, full.AccountSettings, "AccountSettings must always be non-nil")
|
||||||
|
assert.Equal(t, "netbird.cloud", full.DnsDomain)
|
||||||
|
assert.Empty(t, full.Peers)
|
||||||
|
assert.Empty(t, full.Policies)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeNetworkMapEnvelope_AccountSettingsAlwaysEmitted(t *testing.T) {
|
||||||
|
c := &types.NetworkMapComponents{
|
||||||
|
Network: &types.Network{Identifier: "x", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}},
|
||||||
|
// AccountSettings deliberately nil
|
||||||
|
}
|
||||||
|
|
||||||
|
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||||
|
|
||||||
|
require.NotNil(t, full.AccountSettings, "client dereferences AccountSettings unconditionally during Calculate(); a nil here would crash the receiver")
|
||||||
|
assert.False(t, full.AccountSettings.PeerLoginExpirationEnabled)
|
||||||
|
assert.Zero(t, full.AccountSettings.PeerLoginExpirationNs)
|
||||||
|
}
|
||||||
192
management/internals/shared/grpc/components_envelope_response.go
Normal file
192
management/internals/shared/grpc/components_envelope_response.go
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ToComponentSyncResponse builds a SyncResponse carrying the compact
|
||||||
|
// NetworkMapEnvelope for capability-aware peers. The legacy proto.NetworkMap
|
||||||
|
// field is intentionally left empty — capable peers ignore it and the
|
||||||
|
// envelope alone is the authoritative wire shape.
|
||||||
|
//
|
||||||
|
// PeerConfig is computed once server-side using the receiving peer's own
|
||||||
|
// account-level network metadata. EnableSSH inside PeerConfig is left at
|
||||||
|
// peer.SSHEnabled (the peer's local setting); account-policy-driven SSH is
|
||||||
|
// computed by the client from the envelope's GroupIDToUserIDs / AllowedUserIDs
|
||||||
|
// inside Calculate(), so the SshConfig.SshEnabled bit may flip true on the
|
||||||
|
// client even though the server-side PeerConfig reports false.
|
||||||
|
func ToComponentSyncResponse(
|
||||||
|
ctx context.Context,
|
||||||
|
config *nbconfig.Config,
|
||||||
|
httpConfig *nbconfig.HttpServerConfig,
|
||||||
|
deviceFlowConfig *nbconfig.DeviceAuthorizationFlow,
|
||||||
|
peer *nbpeer.Peer,
|
||||||
|
turnCredentials *Token,
|
||||||
|
relayCredentials *Token,
|
||||||
|
components *types.NetworkMapComponents,
|
||||||
|
proxyPatch *types.NetworkMap,
|
||||||
|
dnsName string,
|
||||||
|
checks []*posture.Checks,
|
||||||
|
settings *types.Settings,
|
||||||
|
extraSettings *types.ExtraSettings,
|
||||||
|
peerGroups []string,
|
||||||
|
dnsFwdPort int64,
|
||||||
|
) *proto.SyncResponse {
|
||||||
|
network := networkOrZero(components)
|
||||||
|
enableSSH := computeSSHEnabledForPeer(components, peer)
|
||||||
|
peerConfig := toPeerConfig(peer, network, dnsName, settings, httpConfig, deviceFlowConfig, enableSSH)
|
||||||
|
|
||||||
|
includeIPv6 := peer.SupportsIPv6() && peer.IPv6.IsValid()
|
||||||
|
useSourcePrefixes := peer.SupportsSourcePrefixes()
|
||||||
|
|
||||||
|
userIDClaim := auth.DefaultUserIDClaim
|
||||||
|
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
||||||
|
userIDClaim = httpConfig.AuthUserIDClaim
|
||||||
|
}
|
||||||
|
|
||||||
|
envelope := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||||
|
Components: components,
|
||||||
|
PeerConfig: peerConfig,
|
||||||
|
DNSDomain: dnsName,
|
||||||
|
DNSForwarderPort: dnsFwdPort,
|
||||||
|
UserIDClaim: userIDClaim,
|
||||||
|
ProxyPatch: toProxyPatch(proxyPatch, dnsName, includeIPv6, useSourcePrefixes),
|
||||||
|
})
|
||||||
|
|
||||||
|
resp := &proto.SyncResponse{
|
||||||
|
PeerConfig: peerConfig,
|
||||||
|
NetworkMapEnvelope: envelope,
|
||||||
|
Checks: toProtocolChecks(ctx, checks),
|
||||||
|
}
|
||||||
|
|
||||||
|
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
|
||||||
|
resp.NetbirdConfig = integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
|
||||||
|
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
// networkOrZero returns components.Network or a zero Network — toPeerConfig
|
||||||
|
// dereferences network.Net which would panic on nil.
|
||||||
|
func networkOrZero(c *types.NetworkMapComponents) *types.Network {
|
||||||
|
if c == nil || c.Network == nil {
|
||||||
|
return &types.Network{}
|
||||||
|
}
|
||||||
|
return c.Network
|
||||||
|
}
|
||||||
|
|
||||||
|
// toProxyPatch converts a proxy-injected *types.NetworkMap into the wire
|
||||||
|
// patch the components envelope ships alongside. Returns nil when there are
|
||||||
|
// no fragments to merge — proto3 omits a nil message field, so the receiver
|
||||||
|
// sees no patch and skips the merge step entirely.
|
||||||
|
//
|
||||||
|
// We reuse the legacy proto-conversion helpers (toProtocolRoutes,
|
||||||
|
// toProtocolFirewallRules, toProtocolRoutesFirewallRules,
|
||||||
|
// appendRemotePeerConfig, ForwardingRule.ToProto) because the proxy
|
||||||
|
// delivers fragments pre-expanded — there's no raw component shape to
|
||||||
|
// derive them from. Components purity isn't violated: proxy data isn't
|
||||||
|
// policy-graph-derived, it's externally injected post-Calculate, so the
|
||||||
|
// client merges it on top of its locally-computed NetworkMap.
|
||||||
|
func toProxyPatch(nm *types.NetworkMap, dnsName string, includeIPv6, useSourcePrefixes bool) *proto.ProxyPatch {
|
||||||
|
if nm == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(nm.Peers) == 0 && len(nm.OfflinePeers) == 0 && len(nm.FirewallRules) == 0 &&
|
||||||
|
len(nm.Routes) == 0 && len(nm.RoutesFirewallRules) == 0 && len(nm.ForwardingRules) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
patch := &proto.ProxyPatch{
|
||||||
|
Peers: networkmap.AppendRemotePeerConfig(nil, nm.Peers, dnsName, includeIPv6),
|
||||||
|
OfflinePeers: networkmap.AppendRemotePeerConfig(nil, nm.OfflinePeers, dnsName, includeIPv6),
|
||||||
|
FirewallRules: networkmap.ToProtocolFirewallRules(nm.FirewallRules, includeIPv6, useSourcePrefixes),
|
||||||
|
Routes: networkmap.ToProtocolRoutes(nm.Routes),
|
||||||
|
RouteFirewallRules: networkmap.ToProtocolRoutesFirewallRules(nm.RoutesFirewallRules),
|
||||||
|
}
|
||||||
|
if len(nm.ForwardingRules) > 0 {
|
||||||
|
patch.ForwardingRules = make([]*proto.ForwardingRule, 0, len(nm.ForwardingRules))
|
||||||
|
for _, r := range nm.ForwardingRules {
|
||||||
|
patch.ForwardingRules = append(patch.ForwardingRules, r.ToProto())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return patch
|
||||||
|
}
|
||||||
|
|
||||||
|
// computeSSHEnabledForPeer mirrors the SSH-server-activation bit that
|
||||||
|
// Calculate() folds into NetworkMap.EnableSSH. Components-format peers
|
||||||
|
// receive a freshly-computed PeerConfig.SshConfig.SshEnabled at sync time;
|
||||||
|
// without this helper the field would be incorrectly false for any peer
|
||||||
|
// that's the destination of an SSH-enabling policy without having
|
||||||
|
// peer.SSHEnabled set locally.
|
||||||
|
//
|
||||||
|
// Mirrors the two activation paths Calculate() uses:
|
||||||
|
// 1. Explicit: rule.Protocol == NetbirdSSH and peer is in the rule's
|
||||||
|
// destinations.
|
||||||
|
// 2. Legacy implicit: rule covers TCP/22 or TCP/22022 (or ALL), peer is in
|
||||||
|
// destinations, AND the peer has SSHEnabled set locally — this is the
|
||||||
|
// "allow-all/TCP-22 implies SSH activation for SSH-capable peers" path.
|
||||||
|
//
|
||||||
|
// The full SSH AuthorizedUsers map is still produced by the client when it
|
||||||
|
// runs Calculate() over the envelope.
|
||||||
|
func computeSSHEnabledForPeer(c *types.NetworkMapComponents, peer *nbpeer.Peer) bool {
|
||||||
|
if c == nil || peer == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Mirror Calculate's `getAllPeersFromGroups` invariant: target peer must
|
||||||
|
// exist in c.Peers, otherwise no rule applies to it.
|
||||||
|
if _, ok := c.Peers[peer.ID]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, policy := range c.Policies {
|
||||||
|
if policy == nil || !policy.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, rule := range policy.Rules {
|
||||||
|
if ruleEnablesSSHForPeer(c, rule, peer) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ruleEnablesSSHForPeer returns true when rule is active, targets peer, and
|
||||||
|
// either explicitly authorises SSH or covers the legacy TCP/22 path while the
|
||||||
|
// peer itself has SSH enabled locally.
|
||||||
|
func ruleEnablesSSHForPeer(c *types.NetworkMapComponents, rule *types.PolicyRule, peer *nbpeer.Peer) bool {
|
||||||
|
if rule == nil || !rule.Enabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !peerInDestinations(c, rule, peer.ID) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if rule.Protocol == types.PolicyRuleProtocolNetbirdSSH {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return peer.SSHEnabled && types.PolicyRuleImpliesLegacySSH(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
// peerInDestinations reports whether peerID is in any of rule.Destinations'
|
||||||
|
// groups (or matches DestinationResource if it's a peer-typed resource —
|
||||||
|
// for non-peer types Calculate falls through to group lookup, so we mirror
|
||||||
|
// that exactly to avoid silent divergence).
|
||||||
|
func peerInDestinations(c *types.NetworkMapComponents, rule *types.PolicyRule, peerID string) bool {
|
||||||
|
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||||
|
return rule.DestinationResource.ID == peerID
|
||||||
|
}
|
||||||
|
for _, groupID := range rule.Destinations {
|
||||||
|
if c.IsPeerInGroup(peerID, groupID) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -0,0 +1,184 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestComputeSSHEnabledForPeer covers both Calculate-mirroring branches:
|
||||||
|
// explicit NetbirdSSH protocol, and the legacy implicit case where a
|
||||||
|
// TCP/22 (or 22022 / ALL / port-range-covering-22) rule activates SSH when
|
||||||
|
// the destination peer has SSHEnabled=true locally.
|
||||||
|
func TestComputeSSHEnabledForPeer(t *testing.T) {
|
||||||
|
const targetPeerID = "target"
|
||||||
|
const targetGroupID = "g_dst"
|
||||||
|
|
||||||
|
mkComponents := func(rule *types.PolicyRule, sshEnabled bool) (*types.NetworkMapComponents, *nbpeer.Peer) {
|
||||||
|
peer := &nbpeer.Peer{ID: targetPeerID, SSHEnabled: sshEnabled}
|
||||||
|
group := &types.Group{ID: targetGroupID, Name: "dst", Peers: []string{targetPeerID}}
|
||||||
|
return &types.NetworkMapComponents{
|
||||||
|
Peers: map[string]*nbpeer.Peer{targetPeerID: peer},
|
||||||
|
Groups: map[string]*types.Group{targetGroupID: group},
|
||||||
|
Policies: []*types.Policy{{
|
||||||
|
ID: "p",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*types.PolicyRule{rule},
|
||||||
|
}},
|
||||||
|
}, peer
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
peerSSH bool
|
||||||
|
rule types.PolicyRule
|
||||||
|
wantEnabled bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "explicit-netbird-ssh-activates-regardless-of-peer-ssh",
|
||||||
|
peerSSH: false,
|
||||||
|
rule: types.PolicyRule{
|
||||||
|
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||||
|
Destinations: []string{targetGroupID},
|
||||||
|
},
|
||||||
|
wantEnabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "implicit-tcp-22-with-peer-ssh",
|
||||||
|
peerSSH: true,
|
||||||
|
rule: types.PolicyRule{
|
||||||
|
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
|
||||||
|
Destinations: []string{targetGroupID},
|
||||||
|
},
|
||||||
|
wantEnabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "implicit-tcp-22-without-peer-ssh-disabled",
|
||||||
|
peerSSH: false,
|
||||||
|
rule: types.PolicyRule{
|
||||||
|
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
|
||||||
|
Destinations: []string{targetGroupID},
|
||||||
|
},
|
||||||
|
wantEnabled: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "implicit-tcp-22022-with-peer-ssh",
|
||||||
|
peerSSH: true,
|
||||||
|
rule: types.PolicyRule{
|
||||||
|
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22022"},
|
||||||
|
Destinations: []string{targetGroupID},
|
||||||
|
},
|
||||||
|
wantEnabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "implicit-all-protocol-with-peer-ssh",
|
||||||
|
peerSSH: true,
|
||||||
|
rule: types.PolicyRule{
|
||||||
|
Enabled: true, Protocol: types.PolicyRuleProtocolALL,
|
||||||
|
Destinations: []string{targetGroupID},
|
||||||
|
},
|
||||||
|
wantEnabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "implicit-port-range-covers-22",
|
||||||
|
peerSSH: true,
|
||||||
|
rule: types.PolicyRule{
|
||||||
|
Enabled: true,
|
||||||
|
Protocol: types.PolicyRuleProtocolTCP,
|
||||||
|
PortRanges: []types.RulePortRange{{Start: 20, End: 30}},
|
||||||
|
Destinations: []string{targetGroupID},
|
||||||
|
},
|
||||||
|
wantEnabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tcp-80-no-ssh",
|
||||||
|
peerSSH: true,
|
||||||
|
rule: types.PolicyRule{
|
||||||
|
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"80"},
|
||||||
|
Destinations: []string{targetGroupID},
|
||||||
|
},
|
||||||
|
wantEnabled: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disabled-rule-skipped",
|
||||||
|
peerSSH: true,
|
||||||
|
rule: types.PolicyRule{
|
||||||
|
Enabled: false, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||||
|
Destinations: []string{targetGroupID},
|
||||||
|
},
|
||||||
|
wantEnabled: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "peer-not-in-destinations",
|
||||||
|
peerSSH: true,
|
||||||
|
rule: types.PolicyRule{
|
||||||
|
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||||
|
Destinations: []string{"g_other"}, // target not in this group
|
||||||
|
},
|
||||||
|
wantEnabled: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "peer-typed-destination-resource-matches",
|
||||||
|
peerSSH: false,
|
||||||
|
rule: types.PolicyRule{
|
||||||
|
Enabled: true,
|
||||||
|
Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||||
|
DestinationResource: types.Resource{ID: targetPeerID, Type: types.ResourceTypePeer},
|
||||||
|
},
|
||||||
|
wantEnabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-peer-destination-resource-falls-through-to-groups",
|
||||||
|
peerSSH: false,
|
||||||
|
rule: types.PolicyRule{
|
||||||
|
Enabled: true,
|
||||||
|
Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||||
|
DestinationResource: types.Resource{ID: targetPeerID, Type: "host"}, // wrong type
|
||||||
|
Destinations: []string{targetGroupID}, // saved by group fallback
|
||||||
|
},
|
||||||
|
wantEnabled: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
c, peer := mkComponents(&tc.rule, tc.peerSSH)
|
||||||
|
got := computeSSHEnabledForPeer(c, peer)
|
||||||
|
assert.Equal(t, tc.wantEnabled, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestComputeSSHEnabledForPeer_TargetMissingFromComponents covers the
|
||||||
|
// belt-and-suspenders presence guard mirroring Calculate's
|
||||||
|
// getAllPeersFromGroups invariant.
|
||||||
|
func TestComputeSSHEnabledForPeer_TargetMissingFromComponents(t *testing.T) {
|
||||||
|
peer := &nbpeer.Peer{ID: "missing", SSHEnabled: true}
|
||||||
|
c := &types.NetworkMapComponents{
|
||||||
|
Peers: map[string]*nbpeer.Peer{}, // target peer NOT present
|
||||||
|
Groups: map[string]*types.Group{
|
||||||
|
"g": {ID: "g", Peers: []string{"missing"}},
|
||||||
|
},
|
||||||
|
Policies: []*types.Policy{{
|
||||||
|
ID: "p", Enabled: true,
|
||||||
|
Rules: []*types.PolicyRule{{
|
||||||
|
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||||
|
Destinations: []string{"g"},
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
assert.False(t, computeSSHEnabledForPeer(c, peer),
|
||||||
|
"missing target peer must short-circuit to false, not consult policies")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestComputeSSHEnabledForPeer_NilInputs guards the cheap nil-checks at
|
||||||
|
// function entry — Calculate doesn't accept nil either, but the helper is
|
||||||
|
// exported indirectly via ToComponentSyncResponse and may receive nil
|
||||||
|
// components on graceful-degrade paths.
|
||||||
|
func TestComputeSSHEnabledForPeer_NilInputs(t *testing.T) {
|
||||||
|
assert.False(t, computeSSHEnabledForPeer(nil, &nbpeer.Peer{ID: "x"}))
|
||||||
|
assert.False(t, computeSSHEnabledForPeer(&types.NetworkMapComponents{}, nil))
|
||||||
|
}
|
||||||
@@ -6,24 +6,22 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
goproto "google.golang.org/protobuf/proto"
|
|
||||||
|
|
||||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
nbroute "github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/shared/netiputil"
|
"github.com/netbirdio/netbird/shared/netiputil"
|
||||||
"github.com/netbirdio/netbird/shared/sshauth"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
||||||
@@ -138,8 +136,8 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
|||||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||||
NetworkMap: &proto.NetworkMap{
|
NetworkMap: &proto.NetworkMap{
|
||||||
Serial: networkMap.Network.CurrentSerial(),
|
Serial: networkMap.Network.CurrentSerial(),
|
||||||
Routes: toProtocolRoutes(networkMap.Routes),
|
Routes: networkmap.ToProtocolRoutes(networkMap.Routes),
|
||||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
DNSConfig: networkmap.ToProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||||
},
|
},
|
||||||
Checks: toProtocolChecks(ctx, checks),
|
Checks: toProtocolChecks(ctx, checks),
|
||||||
@@ -152,19 +150,19 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
|||||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||||
|
|
||||||
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||||
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
|
remotePeers = networkmap.AppendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
|
||||||
response.RemotePeers = remotePeers
|
response.RemotePeers = remotePeers
|
||||||
response.NetworkMap.RemotePeers = remotePeers
|
response.NetworkMap.RemotePeers = remotePeers
|
||||||
response.RemotePeersIsEmpty = len(remotePeers) == 0
|
response.RemotePeersIsEmpty = len(remotePeers) == 0
|
||||||
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
|
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
|
||||||
|
|
||||||
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6)
|
response.NetworkMap.OfflinePeers = networkmap.AppendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6)
|
||||||
|
|
||||||
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes)
|
firewallRules := networkmap.ToProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes)
|
||||||
response.NetworkMap.FirewallRules = firewallRules
|
response.NetworkMap.FirewallRules = firewallRules
|
||||||
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
|
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
|
||||||
|
|
||||||
routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
|
routesFirewallRules := networkmap.ToProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
|
||||||
response.NetworkMap.RoutesFirewallRules = routesFirewallRules
|
response.NetworkMap.RoutesFirewallRules = routesFirewallRules
|
||||||
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
|
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
|
||||||
|
|
||||||
@@ -177,7 +175,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
|||||||
}
|
}
|
||||||
|
|
||||||
if networkMap.AuthorizedUsers != nil {
|
if networkMap.AuthorizedUsers != nil {
|
||||||
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
|
hashedUsers, machineUsers := networkmap.BuildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
|
||||||
userIDClaim := auth.DefaultUserIDClaim
|
userIDClaim := auth.DefaultUserIDClaim
|
||||||
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
||||||
userIDClaim = httpConfig.AuthUserIDClaim
|
userIDClaim = httpConfig.AuthUserIDClaim
|
||||||
@@ -185,79 +183,36 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
|||||||
response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
|
response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// settings == nil → field stays nil → "no info in this snapshot", client
|
||||||
|
// preserves the deadline it already had. settings non-nil → emit either a
|
||||||
|
// valid deadline or the explicit-zero "disabled" sentinel via
|
||||||
|
// encodeSessionExpiresAt.
|
||||||
|
if settings != nil {
|
||||||
|
response.SessionExpiresAt = encodeSessionExpiresAt(
|
||||||
|
peer.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
return response
|
return response
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
|
// encodeSessionExpiresAt encodes a server-side deadline into the 3-state wire
|
||||||
userIDToIndex := make(map[string]uint32)
|
// representation used on LoginResponse, SyncResponse and
|
||||||
var hashedUsers [][]byte
|
// ExtendAuthSessionResponse. See the proto comments on those messages.
|
||||||
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers))
|
//
|
||||||
|
// - deadline.IsZero() → returns &Timestamp{} (seconds=0, nanos=0): the
|
||||||
for machineUser, users := range authorizedUsers {
|
// "expiry disabled or peer is not SSO-tracked" sentinel; the client clears
|
||||||
indexes := make([]uint32, 0, len(users))
|
// its anchor.
|
||||||
for userID := range users {
|
// - deadline non-zero → returns timestamppb.New(deadline): the new absolute
|
||||||
idx, exists := userIDToIndex[userID]
|
// UTC deadline.
|
||||||
if !exists {
|
//
|
||||||
hash, err := sshauth.HashUserID(userID)
|
// Returning nil ("no info, preserve client's anchor") is the caller's job —
|
||||||
if err != nil {
|
// only meaningful on Sync builds where settings were not resolved.
|
||||||
log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err)
|
func encodeSessionExpiresAt(deadline time.Time) *timestamppb.Timestamp {
|
||||||
continue
|
if deadline.IsZero() {
|
||||||
}
|
return ×tamppb.Timestamp{}
|
||||||
idx = uint32(len(hashedUsers))
|
|
||||||
userIDToIndex[userID] = idx
|
|
||||||
hashedUsers = append(hashedUsers, hash[:])
|
|
||||||
}
|
|
||||||
indexes = append(indexes, idx)
|
|
||||||
}
|
|
||||||
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
|
|
||||||
}
|
}
|
||||||
|
return timestamppb.New(deadline)
|
||||||
return hashedUsers, machineUsers
|
|
||||||
}
|
|
||||||
|
|
||||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig {
|
|
||||||
for _, rPeer := range peers {
|
|
||||||
allowedIPs := []string{rPeer.IP.String() + "/32"}
|
|
||||||
if includeIPv6 && rPeer.IPv6.IsValid() {
|
|
||||||
allowedIPs = append(allowedIPs, rPeer.IPv6.String()+"/128")
|
|
||||||
}
|
|
||||||
dst = append(dst, &proto.RemotePeerConfig{
|
|
||||||
WgPubKey: rPeer.Key,
|
|
||||||
AllowedIps: allowedIPs,
|
|
||||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
|
||||||
Fqdn: rPeer.FQDN(dnsName),
|
|
||||||
AgentVersion: rPeer.Meta.WtVersion,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
|
||||||
func toProtocolDNSConfig(update nbdns.Config, cache *cache.DNSConfigCache, forwardPort int64) *proto.DNSConfig {
|
|
||||||
protoUpdate := &proto.DNSConfig{
|
|
||||||
ServiceEnable: update.ServiceEnable,
|
|
||||||
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
|
|
||||||
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
|
|
||||||
ForwarderPort: forwardPort,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, zone := range update.CustomZones {
|
|
||||||
protoZone := convertToProtoCustomZone(zone)
|
|
||||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, nsGroup := range update.NameServerGroups {
|
|
||||||
cacheKey := nsGroup.ID
|
|
||||||
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
|
|
||||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
|
|
||||||
} else {
|
|
||||||
protoGroup := convertToProtoNameServerGroup(nsGroup)
|
|
||||||
cache.SetNameServerGroup(cacheKey, protoGroup)
|
|
||||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return protoUpdate
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
|
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
|
||||||
@@ -277,204 +232,6 @@ func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toProtocolRoutes(routes []*nbroute.Route) []*proto.Route {
|
|
||||||
protoRoutes := make([]*proto.Route, 0, len(routes))
|
|
||||||
for _, r := range routes {
|
|
||||||
protoRoutes = append(protoRoutes, toProtocolRoute(r))
|
|
||||||
}
|
|
||||||
return protoRoutes
|
|
||||||
}
|
|
||||||
|
|
||||||
func toProtocolRoute(route *nbroute.Route) *proto.Route {
|
|
||||||
return &proto.Route{
|
|
||||||
ID: string(route.ID),
|
|
||||||
NetID: string(route.NetID),
|
|
||||||
Network: route.Network.String(),
|
|
||||||
Domains: route.Domains.ToPunycodeList(),
|
|
||||||
NetworkType: int64(route.NetworkType),
|
|
||||||
Peer: route.Peer,
|
|
||||||
Metric: int64(route.Metric),
|
|
||||||
Masquerade: route.Masquerade,
|
|
||||||
KeepRoute: route.KeepRoute,
|
|
||||||
SkipAutoApply: route.SkipAutoApply,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
|
|
||||||
// When useSourcePrefixes is true, the compact SourcePrefixes field is populated
|
|
||||||
// alongside the deprecated PeerIP for forward compatibility.
|
|
||||||
// Wildcard rules ("0.0.0.0") are expanded into separate v4 and v6 SourcePrefixes
|
|
||||||
// when includeIPv6 is true.
|
|
||||||
func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSourcePrefixes bool) []*proto.FirewallRule {
|
|
||||||
result := make([]*proto.FirewallRule, 0, len(rules))
|
|
||||||
for i := range rules {
|
|
||||||
rule := rules[i]
|
|
||||||
|
|
||||||
fwRule := &proto.FirewallRule{
|
|
||||||
PolicyID: []byte(rule.PolicyID),
|
|
||||||
PeerIP: rule.PeerIP, //nolint:staticcheck // populated for backward compatibility
|
|
||||||
Direction: getProtoDirection(rule.Direction),
|
|
||||||
Action: getProtoAction(rule.Action),
|
|
||||||
Protocol: getProtoProtocol(rule.Protocol),
|
|
||||||
Port: rule.Port,
|
|
||||||
}
|
|
||||||
|
|
||||||
if useSourcePrefixes && rule.PeerIP != "" {
|
|
||||||
result = append(result, populateSourcePrefixes(fwRule, rule, includeIPv6)...)
|
|
||||||
}
|
|
||||||
|
|
||||||
if shouldUsePortRange(fwRule) {
|
|
||||||
fwRule.PortInfo = rule.PortRange.ToProto()
|
|
||||||
}
|
|
||||||
|
|
||||||
result = append(result, fwRule)
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
|
|
||||||
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified).
|
|
||||||
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {
|
|
||||||
addr, err := netip.ParseAddr(rule.PeerIP)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if !addr.IsUnspecified() {
|
|
||||||
fwRule.SourcePrefixes = [][]byte{netiputil.EncodeAddr(addr.Unmap())}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IPv4Unspecified/0 is always valid, error is impossible.
|
|
||||||
v4Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv4Unspecified(), 0))
|
|
||||||
fwRule.SourcePrefixes = [][]byte{v4Wildcard}
|
|
||||||
|
|
||||||
if !includeIPv6 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
v6Rule := goproto.Clone(fwRule).(*proto.FirewallRule)
|
|
||||||
v6Rule.PeerIP = "::" //nolint:staticcheck // populated for backward compatibility
|
|
||||||
// IPv6Unspecified/0 is always valid, error is impossible.
|
|
||||||
v6Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv6Unspecified(), 0))
|
|
||||||
v6Rule.SourcePrefixes = [][]byte{v6Wildcard}
|
|
||||||
if shouldUsePortRange(v6Rule) {
|
|
||||||
v6Rule.PortInfo = rule.PortRange.ToProto()
|
|
||||||
}
|
|
||||||
return []*proto.FirewallRule{v6Rule}
|
|
||||||
}
|
|
||||||
|
|
||||||
// getProtoDirection converts the direction to proto.RuleDirection.
|
|
||||||
func getProtoDirection(direction int) proto.RuleDirection {
|
|
||||||
if direction == types.FirewallRuleDirectionOUT {
|
|
||||||
return proto.RuleDirection_OUT
|
|
||||||
}
|
|
||||||
return proto.RuleDirection_IN
|
|
||||||
}
|
|
||||||
|
|
||||||
func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
|
|
||||||
result := make([]*proto.RouteFirewallRule, len(rules))
|
|
||||||
for i := range rules {
|
|
||||||
rule := rules[i]
|
|
||||||
result[i] = &proto.RouteFirewallRule{
|
|
||||||
SourceRanges: rule.SourceRanges,
|
|
||||||
Action: getProtoAction(rule.Action),
|
|
||||||
Destination: rule.Destination,
|
|
||||||
Protocol: getProtoProtocol(rule.Protocol),
|
|
||||||
PortInfo: getProtoPortInfo(rule),
|
|
||||||
IsDynamic: rule.IsDynamic,
|
|
||||||
Domains: rule.Domains.ToPunycodeList(),
|
|
||||||
PolicyID: []byte(rule.PolicyID),
|
|
||||||
RouteID: string(rule.RouteID),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// getProtoAction converts the action to proto.RuleAction.
|
|
||||||
func getProtoAction(action string) proto.RuleAction {
|
|
||||||
if action == string(types.PolicyTrafficActionDrop) {
|
|
||||||
return proto.RuleAction_DROP
|
|
||||||
}
|
|
||||||
return proto.RuleAction_ACCEPT
|
|
||||||
}
|
|
||||||
|
|
||||||
// getProtoProtocol converts the protocol to proto.RuleProtocol.
|
|
||||||
func getProtoProtocol(protocol string) proto.RuleProtocol {
|
|
||||||
switch types.PolicyRuleProtocolType(protocol) {
|
|
||||||
case types.PolicyRuleProtocolALL:
|
|
||||||
return proto.RuleProtocol_ALL
|
|
||||||
case types.PolicyRuleProtocolTCP:
|
|
||||||
return proto.RuleProtocol_TCP
|
|
||||||
case types.PolicyRuleProtocolUDP:
|
|
||||||
return proto.RuleProtocol_UDP
|
|
||||||
case types.PolicyRuleProtocolICMP:
|
|
||||||
return proto.RuleProtocol_ICMP
|
|
||||||
default:
|
|
||||||
return proto.RuleProtocol_UNKNOWN
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// getProtoPortInfo converts the port info to proto.PortInfo.
|
|
||||||
func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
|
|
||||||
var portInfo proto.PortInfo
|
|
||||||
if rule.Port != 0 {
|
|
||||||
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
|
|
||||||
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
|
|
||||||
portInfo.PortSelection = &proto.PortInfo_Range_{
|
|
||||||
Range: &proto.PortInfo_Range{
|
|
||||||
Start: uint32(portRange.Start),
|
|
||||||
End: uint32(portRange.End),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &portInfo
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldUsePortRange(rule *proto.FirewallRule) bool {
|
|
||||||
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper function to convert nbdns.CustomZone to proto.CustomZone
|
|
||||||
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
|
||||||
protoZone := &proto.CustomZone{
|
|
||||||
Domain: zone.Domain,
|
|
||||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
|
||||||
SearchDomainDisabled: zone.SearchDomainDisabled,
|
|
||||||
NonAuthoritative: zone.NonAuthoritative,
|
|
||||||
}
|
|
||||||
for _, record := range zone.Records {
|
|
||||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
|
||||||
Name: record.Name,
|
|
||||||
Type: int64(record.Type),
|
|
||||||
Class: record.Class,
|
|
||||||
TTL: int64(record.TTL),
|
|
||||||
RData: record.RData,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return protoZone
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
|
|
||||||
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
|
|
||||||
protoGroup := &proto.NameServerGroup{
|
|
||||||
Primary: nsGroup.Primary,
|
|
||||||
Domains: nsGroup.Domains,
|
|
||||||
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
|
||||||
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
|
|
||||||
}
|
|
||||||
for _, ns := range nsGroup.NameServers {
|
|
||||||
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
|
|
||||||
IP: ns.IP.String(),
|
|
||||||
Port: int64(ns.Port),
|
|
||||||
NSType: int64(ns.NSType),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return protoGroup
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildJWTConfig constructs JWT configuration for SSH servers from management server config
|
// buildJWTConfig constructs JWT configuration for SSH servers from management server config
|
||||||
func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.JWTConfig {
|
func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.JWTConfig {
|
||||||
if config == nil || config.AuthAudience == "" {
|
if config == nil || config.AuthAudience == "" {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
@@ -12,6 +13,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||||
@@ -61,13 +63,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// First run with config1
|
// First run with config1
|
||||||
result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
result1 := networkmap.ToProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||||
|
|
||||||
// Second run with config2
|
// Second run with config2
|
||||||
result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
|
result2 := networkmap.ToProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
|
||||||
|
|
||||||
// Third run with config1 again
|
// Third run with config1 again
|
||||||
result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
result3 := networkmap.ToProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||||
|
|
||||||
// Verify that result1 and result3 are identical
|
// Verify that result1 and result3 are identical
|
||||||
if !reflect.DeepEqual(result1, result3) {
|
if !reflect.DeepEqual(result1, result3) {
|
||||||
@@ -99,7 +101,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
networkmap.ToProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -107,7 +109,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
|||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
cache := &cache.DNSConfigCache{}
|
cache := &cache.DNSConfigCache{}
|
||||||
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
networkmap.ToProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -200,3 +202,29 @@ func TestBuildJWTConfig_Audiences(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestEncodeSessionExpiresAt pins the wire encoding the client's
|
||||||
|
// applySessionDeadline depends on:
|
||||||
|
//
|
||||||
|
// - zero deadline → &Timestamp{} (seconds=0, nanos=0): the explicit
|
||||||
|
// "expiry disabled or peer is not SSO-tracked" sentinel.
|
||||||
|
// - non-zero → timestamppb.New(deadline): the absolute UTC deadline.
|
||||||
|
//
|
||||||
|
// The third state (nil pointer = "no info in this snapshot") is the caller's
|
||||||
|
// responsibility on the Sync path when settings could not be resolved; the
|
||||||
|
// helper itself never returns nil.
|
||||||
|
func TestEncodeSessionExpiresAt(t *testing.T) {
|
||||||
|
t.Run("zero deadline encodes as explicit-zero sentinel", func(t *testing.T) {
|
||||||
|
got := encodeSessionExpiresAt(time.Time{})
|
||||||
|
assert.NotNil(t, got, "must not return nil; nil means 'no info', not 'disabled'")
|
||||||
|
assert.Equal(t, int64(0), got.GetSeconds())
|
||||||
|
assert.Equal(t, int32(0), got.GetNanos())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-zero deadline round-trips", func(t *testing.T) {
|
||||||
|
deadline := time.Date(2030, 1, 2, 3, 4, 5, 0, time.UTC)
|
||||||
|
got := encodeSessionExpiresAt(deadline)
|
||||||
|
assert.NotNil(t, got)
|
||||||
|
assert.True(t, got.AsTime().Equal(deadline))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -351,6 +351,7 @@ func (s *ProxyServiceServer) registerProxyConnection(ctx context.Context, params
|
|||||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||||
RequireSubdomain: c.RequireSubdomain,
|
RequireSubdomain: c.RequireSubdomain,
|
||||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||||
|
Private: c.Private,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -754,6 +755,11 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
|
|||||||
InitialSyncComplete: update.InitialSyncComplete,
|
InitialSyncComplete: update.InitialSyncComplete,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Drop mappings the proxy lacks capability for (e.g. private without SupportsPrivateService).
|
||||||
|
connUpdate = filterMappingsForProxy(conn, connUpdate)
|
||||||
|
if connUpdate == nil || len(connUpdate.Mapping) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
resp := s.perProxyMessage(connUpdate, conn.proxyID)
|
resp := s.perProxyMessage(connUpdate, conn.proxyID)
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
|
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
|
||||||
@@ -882,16 +888,20 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// proxyAcceptsMapping returns whether the proxy should receive this mapping.
|
// proxyAcceptsMapping returns whether the proxy can receive this mapping.
|
||||||
// Old proxies that never reported capabilities are skipped for non-TLS L4
|
// Private mappings require SupportsPrivateService; custom-port L4 mappings
|
||||||
// mappings with a custom listen port, since they don't understand the
|
// require SupportsCustomPorts. Remove operations always pass so proxies can
|
||||||
// protocol. Proxies that report capabilities (even SupportsCustomPorts=false)
|
// clean up.
|
||||||
// are new enough to handle the mapping. TLS uses SNI routing and works on
|
|
||||||
// any proxy. Delete operations are always sent so proxies can clean up.
|
|
||||||
func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool {
|
func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool {
|
||||||
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
|
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
if mapping.GetPrivate() {
|
||||||
|
caps := conn.capabilities
|
||||||
|
if caps == nil || caps.SupportsPrivateService == nil || !*caps.SupportsPrivateService {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
if mapping.ListenPort == 0 || mapping.Mode == "tls" {
|
if mapping.ListenPort == 0 || mapping.Mode == "tls" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -900,6 +910,29 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo
|
|||||||
return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil
|
return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// filterMappingsForProxy drops mappings the proxy cannot safely receive
|
||||||
|
// (e.g. private mappings to a proxy without SupportsPrivateService).
|
||||||
|
// Returns the input unchanged when no filtering is needed.
|
||||||
|
func filterMappingsForProxy(conn *proxyConnection, update *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse {
|
||||||
|
if update == nil || len(update.Mapping) == 0 {
|
||||||
|
return update
|
||||||
|
}
|
||||||
|
kept := make([]*proto.ProxyMapping, 0, len(update.Mapping))
|
||||||
|
for _, m := range update.Mapping {
|
||||||
|
if !proxyAcceptsMapping(conn, m) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
kept = append(kept, m)
|
||||||
|
}
|
||||||
|
if len(kept) == len(update.Mapping) {
|
||||||
|
return update
|
||||||
|
}
|
||||||
|
return &proto.GetMappingUpdateResponse{
|
||||||
|
Mapping: kept,
|
||||||
|
InitialSyncComplete: update.InitialSyncComplete,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// perProxyMessage returns a copy of update with a fresh one-time token for
|
// perProxyMessage returns a copy of update with a fresh one-time token for
|
||||||
// create/update operations. For delete operations the original mapping is
|
// create/update operations. For delete operations the original mapping is
|
||||||
// used unchanged because proxies do not need to authenticate for removal.
|
// used unchanged because proxies do not need to authenticate for removal.
|
||||||
@@ -961,7 +994,10 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
|
|||||||
|
|
||||||
authenticated, userId, method := s.authenticateRequest(ctx, req, service)
|
authenticated, userId, method := s.authenticateRequest(ctx, req, service)
|
||||||
|
|
||||||
token, err := s.generateSessionToken(ctx, authenticated, service, userId, method)
|
// Non-OIDC schemes (PIN/Password/Header) authenticate against per-service
|
||||||
|
// secrets and have no user-level group context, so groups stay nil. Email
|
||||||
|
// is also empty — these schemes don't resolve a user record at sign time.
|
||||||
|
token, err := s.generateSessionToken(ctx, authenticated, service, userId, "", method, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1050,7 +1086,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) {
|
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId, userEmail string, method proxyauth.Method, groupIDs, groupNames []string) (string, error) {
|
||||||
if !authenticated || service.SessionPrivateKey == "" {
|
if !authenticated || service.SessionPrivateKey == "" {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
@@ -1058,8 +1094,11 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
|
|||||||
token, err := sessionkey.SignToken(
|
token, err := sessionkey.SignToken(
|
||||||
service.SessionPrivateKey,
|
service.SessionPrivateKey,
|
||||||
userId,
|
userId,
|
||||||
|
userEmail,
|
||||||
service.Domain,
|
service.Domain,
|
||||||
method,
|
method,
|
||||||
|
groupIDs,
|
||||||
|
groupNames,
|
||||||
proxyauth.DefaultSessionExpiry,
|
proxyauth.DefaultSessionExpiry,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1070,6 +1109,26 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
|
|||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// pairGroupIDsAndNames splits a slice of resolved *types.Group records
|
||||||
|
// into parallel id and name slices. ids[i] and names[i] always pair to
|
||||||
|
// the same group. nil entries (orphan ids the manager couldn't resolve)
|
||||||
|
// are skipped so the consumer can rely on positional pairing.
|
||||||
|
func pairGroupIDsAndNames(groups []*types.Group) (ids, names []string) {
|
||||||
|
if len(groups) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
ids = make([]string, 0, len(groups))
|
||||||
|
names = make([]string, 0, len(groups))
|
||||||
|
for _, g := range groups {
|
||||||
|
if g == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ids = append(ids, g.ID)
|
||||||
|
names = append(names, g.Name)
|
||||||
|
}
|
||||||
|
return ids, names
|
||||||
|
}
|
||||||
|
|
||||||
// SendStatusUpdate handles status updates from proxy clients.
|
// SendStatusUpdate handles status updates from proxy clients.
|
||||||
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
|
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
|
||||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||||
@@ -1334,7 +1393,9 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
|
|||||||
return verifier, redirectURL, nil
|
return verifier, redirectURL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateSessionToken creates a signed session JWT for the given domain and user.
|
// GenerateSessionToken creates a signed session JWT for the given domain and
|
||||||
|
// user. The user's group memberships are embedded in the token so policy-aware
|
||||||
|
// middlewares on the proxy can authorise without an extra management round-trip.
|
||||||
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
|
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
|
||||||
service, err := s.getServiceByDomain(ctx, domain)
|
service, err := s.getServiceByDomain(ctx, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1345,11 +1406,29 @@ func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, u
|
|||||||
return "", fmt.Errorf("no session key configured for domain: %s", domain)
|
return "", fmt.Errorf("no session key configured for domain: %s", domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
email string
|
||||||
|
groupIDs []string
|
||||||
|
groupNames []string
|
||||||
|
)
|
||||||
|
if s.usersManager != nil {
|
||||||
|
user, userGroups, uerr := s.usersManager.GetUserWithGroups(ctx, userID)
|
||||||
|
if uerr != nil {
|
||||||
|
log.WithContext(ctx).Debugf("session token mint: lookup user %s: %v", userID, uerr)
|
||||||
|
} else if user != nil {
|
||||||
|
email = user.Email
|
||||||
|
groupIDs, groupNames = pairGroupIDsAndNames(userGroups)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return sessionkey.SignToken(
|
return sessionkey.SignToken(
|
||||||
service.SessionPrivateKey,
|
service.SessionPrivateKey,
|
||||||
userID,
|
userID,
|
||||||
|
email,
|
||||||
domain,
|
domain,
|
||||||
method,
|
method,
|
||||||
|
groupIDs,
|
||||||
|
groupNames,
|
||||||
proxyauth.DefaultSessionExpiry,
|
proxyauth.DefaultSessionExpiry,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -1453,7 +1532,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes)
|
userID, _, _, _, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"domain": domain,
|
"domain": domain,
|
||||||
@@ -1466,7 +1545,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := s.usersManager.GetUser(ctx, userID)
|
user, userGroups, err := s.usersManager.GetUserWithGroups(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"domain": domain,
|
"domain": domain,
|
||||||
@@ -1500,12 +1579,15 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
|||||||
"user_id": userID,
|
"user_id": userID,
|
||||||
"error": err.Error(),
|
"error": err.Error(),
|
||||||
}).Debug("ValidateSession: access denied")
|
}).Debug("ValidateSession: access denied")
|
||||||
|
groupIDs, groupNames := pairGroupIDsAndNames(userGroups)
|
||||||
//nolint:nilerr
|
//nolint:nilerr
|
||||||
return &proto.ValidateSessionResponse{
|
return &proto.ValidateSessionResponse{
|
||||||
Valid: false,
|
Valid: false,
|
||||||
UserId: user.Id,
|
UserId: user.Id,
|
||||||
UserEmail: user.Email,
|
UserEmail: user.Email,
|
||||||
DeniedReason: "not_in_group",
|
DeniedReason: "not_in_group",
|
||||||
|
PeerGroupIds: groupIDs,
|
||||||
|
PeerGroupNames: groupNames,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1515,10 +1597,13 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
|||||||
"email": user.Email,
|
"email": user.Email,
|
||||||
}).Debug("ValidateSession: access granted")
|
}).Debug("ValidateSession: access granted")
|
||||||
|
|
||||||
|
groupIDs, groupNames := pairGroupIDsAndNames(userGroups)
|
||||||
return &proto.ValidateSessionResponse{
|
return &proto.ValidateSessionResponse{
|
||||||
Valid: true,
|
Valid: true,
|
||||||
UserId: user.Id,
|
UserId: user.Id,
|
||||||
UserEmail: user.Email,
|
UserEmail: user.Email,
|
||||||
|
PeerGroupIds: groupIDs,
|
||||||
|
PeerGroupNames: groupNames,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1551,3 +1636,154 @@ func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *
|
|||||||
}
|
}
|
||||||
|
|
||||||
func ptr[T any](v T) *T { return &v }
|
func ptr[T any](v T) *T { return &v }
|
||||||
|
|
||||||
|
// ValidateTunnelPeer resolves an inbound peer by its WireGuard tunnel IP and
|
||||||
|
// checks the peer's group membership against the service's access groups.
|
||||||
|
// Peers without a user (machine agents, automation workloads) are first-class
|
||||||
|
// callers; authorisation runs off peer-group memberships rather than the
|
||||||
|
// optional owning user's auto-groups. On success a session JWT is minted so
|
||||||
|
// the proxy can install a cookie and skip subsequent management round-trips.
|
||||||
|
func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||||
|
domain := req.GetDomain()
|
||||||
|
tunnelIPStr := req.GetTunnelIp()
|
||||||
|
|
||||||
|
if domain == "" || tunnelIPStr == "" {
|
||||||
|
return &proto.ValidateTunnelPeerResponse{
|
||||||
|
Valid: false,
|
||||||
|
DeniedReason: "missing domain or tunnel_ip",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tunnelIP := net.ParseIP(tunnelIPStr)
|
||||||
|
if tunnelIP == nil {
|
||||||
|
return &proto.ValidateTunnelPeerResponse{
|
||||||
|
Valid: false,
|
||||||
|
DeniedReason: "invalid_tunnel_ip",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
service, err := s.getServiceByDomain(ctx, domain)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{"domain": domain, "error": err.Error()}).Debug("ValidateTunnelPeer: service not found")
|
||||||
|
//nolint:nilerr
|
||||||
|
return &proto.ValidateTunnelPeerResponse{
|
||||||
|
Valid: false,
|
||||||
|
DeniedReason: "service_not_found",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mirror ValidateSession: account-scoped (BYOP) proxy tokens may only
|
||||||
|
// validate and mint session cookies for their own account's domains.
|
||||||
|
if err := enforceAccountScope(ctx, service.AccountID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
peer, err := s.peersManager.GetPeerByTunnelIP(ctx, service.AccountID, tunnelIP)
|
||||||
|
if err != nil || peer == nil {
|
||||||
|
log.WithFields(log.Fields{"domain": domain, "tunnel_ip": tunnelIPStr}).Debug("ValidateTunnelPeer: peer not found")
|
||||||
|
//nolint:nilerr
|
||||||
|
return &proto.ValidateTunnelPeerResponse{
|
||||||
|
Valid: false,
|
||||||
|
DeniedReason: "peer_not_found",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, peerGroups, err := s.peersManager.GetPeerWithGroups(ctx, service.AccountID, peer.ID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: peer groups lookup failed")
|
||||||
|
//nolint:nilerr
|
||||||
|
return &proto.ValidateTunnelPeerResponse{
|
||||||
|
Valid: false,
|
||||||
|
DeniedReason: "peer_not_found",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
|
||||||
|
|
||||||
|
// Resolve the principal: when the peer is linked to a user, the human
|
||||||
|
// is the principal so multiple peers owned by the same user share a
|
||||||
|
// single identity. Unlinked peers (machine agents) are their own
|
||||||
|
// principal keyed on peer.ID. displayIdentity is what upstream gateways
|
||||||
|
// tag spend with — user.Email when linked, peer.Name when not.
|
||||||
|
principalID := peer.ID
|
||||||
|
displayIdentity := peer.Name
|
||||||
|
if peer.UserID != "" {
|
||||||
|
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
|
||||||
|
principalID = user.Id
|
||||||
|
if user.Email != "" {
|
||||||
|
displayIdentity = user.Email
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
|
||||||
|
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
|
||||||
|
//nolint:nilerr
|
||||||
|
return &proto.ValidateTunnelPeerResponse{
|
||||||
|
Valid: false,
|
||||||
|
UserId: principalID,
|
||||||
|
UserEmail: displayIdentity,
|
||||||
|
DeniedReason: "not_in_group",
|
||||||
|
PeerGroupIds: groupIDs,
|
||||||
|
PeerGroupNames: groupNames,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := s.generateSessionToken(ctx, true, service, principalID, displayIdentity, proxyauth.MethodOIDC, groupIDs, groupNames)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"tunnel_ip": tunnelIPStr,
|
||||||
|
"peer_id": peer.ID,
|
||||||
|
"principal_id": principalID,
|
||||||
|
}).Debug("ValidateTunnelPeer: access granted")
|
||||||
|
|
||||||
|
return &proto.ValidateTunnelPeerResponse{
|
||||||
|
Valid: true,
|
||||||
|
UserId: principalID,
|
||||||
|
UserEmail: displayIdentity,
|
||||||
|
SessionToken: token,
|
||||||
|
PeerGroupIds: groupIDs,
|
||||||
|
PeerGroupNames: groupNames,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
|
||||||
|
// groups. Private services authorise against AccessGroups (empty list fails
|
||||||
|
// closed — Validate() rejects that at save time but the RPC is the security
|
||||||
|
// boundary and must not trust upstream state). Bearer-auth services authorise
|
||||||
|
// against DistributionGroups when populated. Non-private non-bearer services
|
||||||
|
// are open.
|
||||||
|
func checkPeerGroupAccess(service *rpservice.Service, peerGroupIDs []string) error {
|
||||||
|
if service.Private {
|
||||||
|
if len(service.AccessGroups) == 0 {
|
||||||
|
return fmt.Errorf("private service has no access groups")
|
||||||
|
}
|
||||||
|
return matchAnyGroup(service.AccessGroups, peerGroupIDs)
|
||||||
|
}
|
||||||
|
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled && len(service.Auth.BearerAuth.DistributionGroups) > 0 {
|
||||||
|
return matchAnyGroup(service.Auth.BearerAuth.DistributionGroups, peerGroupIDs)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchAnyGroup returns nil when peerGroupIDs intersects allowedGroups,
|
||||||
|
// else a non-nil error.
|
||||||
|
func matchAnyGroup(allowedGroups, peerGroupIDs []string) error {
|
||||||
|
if len(allowedGroups) == 0 {
|
||||||
|
return fmt.Errorf("no allowed groups configured")
|
||||||
|
}
|
||||||
|
allowed := make(map[string]struct{}, len(allowedGroups))
|
||||||
|
for _, g := range allowedGroups {
|
||||||
|
allowed[g] = struct{}{}
|
||||||
|
}
|
||||||
|
for _, g := range peerGroupIDs {
|
||||||
|
if _, ok := allowed[g]; ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("peer not in allowed groups")
|
||||||
|
}
|
||||||
|
|||||||
@@ -129,6 +129,14 @@ func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.U
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) {
|
||||||
|
user, err := m.GetUser(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return user, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestValidateUserGroupAccess(t *testing.T) {
|
func TestValidateUserGroupAccess(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -420,3 +428,46 @@ func TestGetAccountProxyByDomain(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCheckPeerGroupAccess(t *testing.T) {
|
||||||
|
t.Run("private with empty AccessGroups denies", func(t *testing.T) {
|
||||||
|
svc := &service.Service{Private: true, AccessGroups: nil}
|
||||||
|
err := checkPeerGroupAccess(svc, []string{"grp-admins"})
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no access groups")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("private with peer in AccessGroups allows", func(t *testing.T) {
|
||||||
|
svc := &service.Service{Private: true, AccessGroups: []string{"grp-admins", "grp-ops"}}
|
||||||
|
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-other", "grp-ops"}))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("private with peer outside AccessGroups denies", func(t *testing.T) {
|
||||||
|
svc := &service.Service{Private: true, AccessGroups: []string{"grp-admins"}}
|
||||||
|
assert.Error(t, checkPeerGroupAccess(svc, []string{"grp-other"}))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("bearer enabled with empty DistributionGroups allows", func(t *testing.T) {
|
||||||
|
svc := &service.Service{
|
||||||
|
Auth: service.AuthConfig{BearerAuth: &service.BearerAuthConfig{Enabled: true}},
|
||||||
|
}
|
||||||
|
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-anyone"}))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("bearer enabled gates on DistributionGroups", func(t *testing.T) {
|
||||||
|
svc := &service.Service{
|
||||||
|
Auth: service.AuthConfig{
|
||||||
|
BearerAuth: &service.BearerAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
DistributionGroups: []string{"grp-allowed"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-allowed"}))
|
||||||
|
assert.Error(t, checkPeerGroupAccess(svc, []string{"grp-other"}))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-private non-bearer is open", func(t *testing.T) {
|
||||||
|
assert.NoError(t, checkPeerGroupAccess(&service.Service{}, nil))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -437,7 +437,7 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
log.WithContext(ctx).Tracef("received an update for peer %s", peerKey.String())
|
||||||
if debouncer.ProcessUpdate(update) {
|
if debouncer.ProcessUpdate(update) {
|
||||||
// Send immediately (first update or after quiet period)
|
// Send immediately (first update or after quiet period)
|
||||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
|
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
|
||||||
@@ -492,7 +492,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
|
|||||||
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||||
return status.Errorf(codes.Internal, "failed sending update message")
|
return status.Errorf(codes.Internal, "failed sending update message")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
log.WithContext(ctx).Tracef("sent an update to peer %s", peerKey.String())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -821,6 +821,80 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExtendAuthSession refreshes the peer's SSO session expiry deadline using a
|
||||||
|
// fresh JWT. The same JWT validation pipeline as Login is used. The tunnel
|
||||||
|
// stays up; no network map sync is performed. The new deadline is returned
|
||||||
|
// in ExtendAuthSessionResponse.SessionExpiresAt.
|
||||||
|
func (s *Server) ExtendAuthSession(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||||
|
extendReq := &proto.ExtendAuthSessionRequest{}
|
||||||
|
peerKey, err := s.parseRequest(ctx, req, extendReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint
|
||||||
|
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
||||||
|
if accountID, accErr := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()); accErr == nil {
|
||||||
|
//nolint
|
||||||
|
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
jwt := extendReq.GetJwtToken()
|
||||||
|
if jwt == "" {
|
||||||
|
return nil, status.Errorf(codes.InvalidArgument, "jwt token is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
var userID string
|
||||||
|
const attempts = 3
|
||||||
|
for i := 0; i < attempts; i++ {
|
||||||
|
userID, err = s.validateToken(ctx, peerKey.String(), jwt)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if i == attempts-1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Warnf("failed validating JWT token while extending session for peer %s: %v. Retrying (idP cache).", peerKey.String(), err)
|
||||||
|
select {
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if userID == "" {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "jwt token did not yield a user id")
|
||||||
|
}
|
||||||
|
|
||||||
|
deadline, err := s.accountManager.ExtendPeerSession(ctx, peerKey.String(), userID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed extending session for peer %s: %v", peerKey.String(), err)
|
||||||
|
return nil, mapError(ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Success path normally returns a non-zero deadline. A defensive zero
|
||||||
|
// would still encode as the explicit "disabled" sentinel rather than nil,
|
||||||
|
// so the client clears any stale anchor instead of preserving it.
|
||||||
|
resp := &proto.ExtendAuthSessionResponse{
|
||||||
|
SessionExpiresAt: encodeSessionExpiresAt(deadline),
|
||||||
|
}
|
||||||
|
|
||||||
|
wgKey, err := s.secretsManager.GetWGKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "failed processing request")
|
||||||
|
}
|
||||||
|
encrypted, err := encryption.EncryptMessage(peerKey, wgKey, resp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "failed encrypting response")
|
||||||
|
}
|
||||||
|
return &proto.EncryptedMessage{
|
||||||
|
WgPubKey: wgKey.PublicKey().String(),
|
||||||
|
Body: encrypted,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
|
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
|
||||||
var relayToken *Token
|
var relayToken *Token
|
||||||
var err error
|
var err error
|
||||||
@@ -844,6 +918,12 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
|
|||||||
Checks: toProtocolChecks(ctx, postureChecks),
|
Checks: toProtocolChecks(ctx, postureChecks),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// settings is always non-nil here, so we never emit nil — encoder returns
|
||||||
|
// either a valid deadline or the explicit-zero "disabled" sentinel.
|
||||||
|
loginResp.SessionExpiresAt = encodeSessionExpiresAt(
|
||||||
|
peer.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration),
|
||||||
|
)
|
||||||
|
|
||||||
return loginResp, nil
|
return loginResp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -932,7 +1012,31 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
|
|||||||
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
|
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
dnsName := s.networkMapController.GetDNSDomain(settings)
|
||||||
|
|
||||||
|
var plainResp *proto.SyncResponse
|
||||||
|
if s.networkMapController.PeerNeedsComponents(peer) {
|
||||||
|
// Capable peer: discard the legacy NetworkMap that SyncAndMarkPeer
|
||||||
|
// computed and recompute the raw components instead. This wastes one
|
||||||
|
// Calculate() call per initial-sync — the component-based wire
|
||||||
|
// format is what the peer actually consumes. The streaming path
|
||||||
|
// (network_map.Controller.UpdateAccountPeers) skips this duplication
|
||||||
|
// because it dispatches by capability before computing.
|
||||||
|
//
|
||||||
|
// TODO: refactor SyncPeer / SyncAndMarkPeer / their mocks + manager
|
||||||
|
// interfaces to return PeerNetworkMapResult so the initial-sync path
|
||||||
|
// stops doing duplicate work. Deferred until the client-side
|
||||||
|
// decoder lands and there's a real deployment of capability=3 peers
|
||||||
|
// worth optimizing for.
|
||||||
|
_, components, proxyPatch, _, _, err := s.networkMapController.GetValidatedPeerWithComponents(ctx, false, peer.AccountID, peer)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to build components for peer %s on initial sync: %v", peer.ID, err)
|
||||||
|
return status.Errorf(codes.Internal, "failed to build initial sync envelope")
|
||||||
|
}
|
||||||
|
plainResp = ToComponentSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, components, proxyPatch, dnsName, postureChecks, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||||
|
} else {
|
||||||
|
plainResp = ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, dnsName, postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||||
|
}
|
||||||
|
|
||||||
key, err := s.secretsManager.GetWGKey()
|
key, err := s.secretsManager.GetWGKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ func generateSessionKeyPair(t *testing.T) (string, string) {
|
|||||||
|
|
||||||
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
|
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, time.Hour)
|
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, nil, time.Hour)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return token
|
return token
|
||||||
}
|
}
|
||||||
@@ -125,6 +125,7 @@ func TestValidateSession_UserAllowed(t *testing.T) {
|
|||||||
assert.True(t, resp.Valid, "User should be allowed access")
|
assert.True(t, resp.Valid, "User should be allowed access")
|
||||||
assert.Equal(t, "allowedUserId", resp.UserId)
|
assert.Equal(t, "allowedUserId", resp.UserId)
|
||||||
assert.Empty(t, resp.DeniedReason)
|
assert.Empty(t, resp.DeniedReason)
|
||||||
|
assert.Equal(t, []string{"allowedGroupId"}, resp.GetPeerGroupIds(), "PeerGroupIds must mirror the resolved user's group memberships")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
|
func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
|
||||||
@@ -145,6 +146,7 @@ func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
|
|||||||
assert.False(t, resp.Valid, "User not in group should be denied")
|
assert.False(t, resp.Valid, "User not in group should be denied")
|
||||||
assert.Equal(t, "not_in_group", resp.DeniedReason)
|
assert.Equal(t, "not_in_group", resp.DeniedReason)
|
||||||
assert.Equal(t, "nonGroupUserId", resp.UserId)
|
assert.Equal(t, "nonGroupUserId", resp.UserId)
|
||||||
|
assert.Empty(t, resp.GetPeerGroupIds(), "PeerGroupIds must mirror the resolved user's actual (empty) memberships on denial")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateSession_UserInDifferentAccount(t *testing.T) {
|
func TestValidateSession_UserInDifferentAccount(t *testing.T) {
|
||||||
|
|||||||
@@ -355,7 +355,17 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
|
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
|
||||||
oldSettings.DNSDomain != newSettings.DNSDomain ||
|
oldSettings.DNSDomain != newSettings.DNSDomain ||
|
||||||
oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion ||
|
oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion ||
|
||||||
oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways {
|
oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways ||
|
||||||
|
oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled ||
|
||||||
|
oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
|
||||||
|
// Session deadline is derived from LastLogin + PeerLoginExpiration
|
||||||
|
// on every Login/Sync response. Without a fan-out push, connected
|
||||||
|
// peers keep the deadline they received at login time and only see
|
||||||
|
// the new value after the next unrelated NetworkMap change. Add
|
||||||
|
// these two fields to the trigger list so admin-side expiry tweaks
|
||||||
|
// (e.g. shortening from 24h to 1h) reach every connected peer
|
||||||
|
// within seconds, which is what the proactive-warning feature
|
||||||
|
// relies on (see client/internal/auth/sessionwatch).
|
||||||
updateAccountPeers = true
|
updateAccountPeers = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1621,6 +1631,14 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, g := range newGroupsToCreate {
|
||||||
|
seq, err := transaction.AllocateAccountSeqID(ctx, userAuth.AccountId, types.AccountSeqEntityGroup)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error allocating group seq id: %w", err)
|
||||||
|
}
|
||||||
|
g.AccountSeqID = seq
|
||||||
|
}
|
||||||
|
|
||||||
if err = transaction.CreateGroups(ctx, userAuth.AccountId, newGroupsToCreate); err != nil {
|
if err = transaction.CreateGroups(ctx, userAuth.AccountId, newGroupsToCreate); err != nil {
|
||||||
return fmt.Errorf("error saving groups: %w", err)
|
return fmt.Errorf("error saving groups: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -109,6 +109,7 @@ type Manager interface {
|
|||||||
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||||
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
||||||
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||||
|
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
|
||||||
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
|
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
|
||||||
GetExternalCacheManager() ExternalCacheManager
|
GetExternalCacheManager() ExternalCacheManager
|
||||||
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||||
|
|||||||
@@ -1304,6 +1304,21 @@ func (mr *MockManagerMockRecorder) LoginPeer(ctx, login interface{}) *gomock.Cal
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginPeer", reflect.TypeOf((*MockManager)(nil).LoginPeer), ctx, login)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginPeer", reflect.TypeOf((*MockManager)(nil).LoginPeer), ctx, login)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExtendPeerSession mocks base method.
|
||||||
|
func (m *MockManager) ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "ExtendPeerSession", ctx, peerPubKey, userID)
|
||||||
|
ret0, _ := ret[0].(time.Time)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtendPeerSession indicates an expected call of ExtendPeerSession.
|
||||||
|
func (mr *MockManagerMockRecorder) ExtendPeerSession(ctx, peerPubKey, userID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExtendPeerSession", reflect.TypeOf((*MockManager)(nil).ExtendPeerSession), ctx, peerPubKey, userID)
|
||||||
|
}
|
||||||
|
|
||||||
// MarkPeerConnected mocks base method.
|
// MarkPeerConnected mocks base method.
|
||||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
|
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|||||||
@@ -3036,6 +3036,16 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
|
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
|
||||||
assert.NoError(t, err, "unable to get user")
|
assert.NoError(t, err, "unable to get user")
|
||||||
assert.Len(t, user.AutoGroups, 1, "new group should be added")
|
assert.Len(t, user.AutoGroups, 1, "new group should be added")
|
||||||
|
|
||||||
|
var newJWTGroup *types.Group
|
||||||
|
for _, g := range groups {
|
||||||
|
if g.Name == "group3" {
|
||||||
|
newJWTGroup = g
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotNil(t, newJWTGroup, "JIT-created JWT group not found")
|
||||||
|
assert.NotZero(t, newJWTGroup.AccountSeqID, "JIT-created JWT group must have a non-zero AccountSeqID")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("remove all JWT groups when list is empty", func(t *testing.T) {
|
t.Run("remove all JWT groups when list is empty", func(t *testing.T) {
|
||||||
|
|||||||
@@ -240,6 +240,10 @@ const (
|
|||||||
AccountLocalMfaEnabled Activity = 123
|
AccountLocalMfaEnabled Activity = 123
|
||||||
// AccountLocalMfaDisabled indicates that a user disabled TOTP MFA for local users
|
// AccountLocalMfaDisabled indicates that a user disabled TOTP MFA for local users
|
||||||
AccountLocalMfaDisabled Activity = 124
|
AccountLocalMfaDisabled Activity = 124
|
||||||
|
// UserExtendedPeerSession indicates that a user refreshed their peer's
|
||||||
|
// SSO session deadline via ExtendAuthSession without re-establishing the
|
||||||
|
// tunnel. Distinct from UserLoggedInPeer (full interactive login).
|
||||||
|
UserExtendedPeerSession Activity = 125
|
||||||
|
|
||||||
AccountDeleted Activity = 99999
|
AccountDeleted Activity = 99999
|
||||||
)
|
)
|
||||||
@@ -394,6 +398,8 @@ var activityMap = map[Activity]Code{
|
|||||||
AccountLocalMfaEnabled: {"Account local MFA enabled", "account.setting.local.mfa.enable"},
|
AccountLocalMfaEnabled: {"Account local MFA enabled", "account.setting.local.mfa.enable"},
|
||||||
AccountLocalMfaDisabled: {"Account local MFA disabled", "account.setting.local.mfa.disable"},
|
AccountLocalMfaDisabled: {"Account local MFA disabled", "account.setting.local.mfa.disable"},
|
||||||
|
|
||||||
|
UserExtendedPeerSession: {"User extended peer session", "user.peer.session.extend"},
|
||||||
|
|
||||||
DomainAdded: {"Domain added", "domain.add"},
|
DomainAdded: {"Domain added", "domain.add"},
|
||||||
DomainDeleted: {"Domain deleted", "domain.delete"},
|
DomainDeleted: {"Domain deleted", "domain.delete"},
|
||||||
DomainValidated: {"Domain validated", "domain.validate"},
|
DomainValidated: {"Domain validated", "domain.validate"},
|
||||||
|
|||||||
@@ -96,6 +96,12 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
|
||||||
|
if err != nil {
|
||||||
|
return status.Errorf(status.Internal, "failed to allocate group seq id: %v", err)
|
||||||
|
}
|
||||||
|
newGroup.AccountSeqID = seq
|
||||||
|
|
||||||
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
|
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||||
return status.Errorf(status.Internal, "failed to create group: %v", err)
|
return status.Errorf(status.Internal, "failed to create group: %v", err)
|
||||||
}
|
}
|
||||||
@@ -170,6 +176,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
newGroup.AccountSeqID = oldGroup.AccountSeqID
|
||||||
|
|
||||||
if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
|
if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -221,6 +229,12 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
|||||||
|
|
||||||
newGroup.AccountID = accountID
|
newGroup.AccountID = accountID
|
||||||
|
|
||||||
|
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
newGroup.AccountSeqID = seq
|
||||||
|
|
||||||
if err = transaction.CreateGroup(ctx, newGroup); err != nil {
|
if err = transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -320,6 +334,12 @@ func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountI
|
|||||||
|
|
||||||
newGroup.AccountID = accountID
|
newGroup.AccountID = accountID
|
||||||
|
|
||||||
|
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
newGroup.AccountSeqID = oldGroup.AccountSeqID
|
||||||
|
|
||||||
if err := transaction.UpdateGroup(ctx, newGroup); err != nil {
|
if err := transaction.UpdateGroup(ctx, newGroup); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,15 +15,13 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||||
|
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||||
@@ -32,12 +30,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
|
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
|
||||||
|
|
||||||
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
|
|
||||||
"github.com/netbirdio/netbird/management/server/auth"
|
"github.com/netbirdio/netbird/management/server/auth"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
nbgroups "github.com/netbirdio/netbird/management/server/groups"
|
nbgroups "github.com/netbirdio/netbird/management/server/groups"
|
||||||
@@ -56,17 +52,14 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||||
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
|
||||||
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
|
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
|
||||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
)
|
)
|
||||||
|
|
||||||
const apiPrefix = "/api"
|
|
||||||
|
|
||||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
|
func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter, isValidChildAccount middleware.IsValidChildAccountFunc) (http.Handler, error) {
|
||||||
|
|
||||||
// Register bypass paths for unauthenticated endpoints
|
// Register bypass paths for unauthenticated endpoints
|
||||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||||
@@ -100,25 +93,16 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
|||||||
accountManager.GetUserFromUserAuth,
|
accountManager.GetUserFromUserAuth,
|
||||||
rateLimiter,
|
rateLimiter,
|
||||||
appMetrics.GetMeter(),
|
appMetrics.GetMeter(),
|
||||||
|
isValidChildAccount,
|
||||||
)
|
)
|
||||||
|
|
||||||
corsMiddleware := cors.AllowAll()
|
corsMiddleware := cors.AllowAll()
|
||||||
|
|
||||||
rootRouter := mux.NewRouter()
|
|
||||||
metricsMiddleware := appMetrics.HTTPMiddleware()
|
metricsMiddleware := appMetrics.HTTPMiddleware()
|
||||||
|
|
||||||
prefix := apiPrefix
|
|
||||||
router := rootRouter.PathPrefix(prefix).Subrouter()
|
|
||||||
|
|
||||||
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler)
|
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler)
|
||||||
|
|
||||||
if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, integratedValidator, appMetrics.GetMeter(), permissionsManager, peersManager, proxyController, settingsManager); err != nil {
|
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), idpManager)
|
||||||
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if embedded IdP is enabled for instance manager
|
|
||||||
embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager)
|
|
||||||
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create instance manager: %w", err)
|
return nil, fmt.Errorf("failed to create instance manager: %w", err)
|
||||||
}
|
}
|
||||||
@@ -154,10 +138,5 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
|||||||
oauthHandler.RegisterEndpoints(router)
|
oauthHandler.RegisterEndpoints(router)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mount embedded IdP handler at /oauth2 path if configured
|
return router, nil
|
||||||
if embeddedIdpEnabled {
|
|
||||||
rootRouter.PathPrefix("/oauth2").Handler(corsMiddleware.Handler(embeddedIdP.Handler()))
|
|
||||||
}
|
|
||||||
|
|
||||||
return rootRouter, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,8 +11,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"go.opentelemetry.io/otel/metric"
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
|
||||||
|
|
||||||
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||||
@@ -27,6 +25,8 @@ type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) err
|
|||||||
|
|
||||||
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||||
|
|
||||||
|
type IsValidChildAccountFunc func(ctx context.Context, userID, accountID, childAccountID string) bool
|
||||||
|
|
||||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||||
type AuthMiddleware struct {
|
type AuthMiddleware struct {
|
||||||
authManager serverauth.Manager
|
authManager serverauth.Manager
|
||||||
@@ -35,6 +35,7 @@ type AuthMiddleware struct {
|
|||||||
syncUserJWTGroups SyncUserJWTGroupsFunc
|
syncUserJWTGroups SyncUserJWTGroupsFunc
|
||||||
rateLimiter *APIRateLimiter
|
rateLimiter *APIRateLimiter
|
||||||
patUsageTracker *PATUsageTracker
|
patUsageTracker *PATUsageTracker
|
||||||
|
isValidChildAccount IsValidChildAccountFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAuthMiddleware instance constructor
|
// NewAuthMiddleware instance constructor
|
||||||
@@ -45,6 +46,7 @@ func NewAuthMiddleware(
|
|||||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||||
rateLimiter *APIRateLimiter,
|
rateLimiter *APIRateLimiter,
|
||||||
meter metric.Meter,
|
meter metric.Meter,
|
||||||
|
isValidChildAccount IsValidChildAccountFunc,
|
||||||
) *AuthMiddleware {
|
) *AuthMiddleware {
|
||||||
var patUsageTracker *PATUsageTracker
|
var patUsageTracker *PATUsageTracker
|
||||||
if meter != nil {
|
if meter != nil {
|
||||||
@@ -62,6 +64,7 @@ func NewAuthMiddleware(
|
|||||||
getUserFromUserAuth: getUserFromUserAuth,
|
getUserFromUserAuth: getUserFromUserAuth,
|
||||||
rateLimiter: rateLimiter,
|
rateLimiter: rateLimiter,
|
||||||
patUsageTracker: patUsageTracker,
|
patUsageTracker: patUsageTracker,
|
||||||
|
isValidChildAccount: isValidChildAccount,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,7 +127,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
|||||||
}
|
}
|
||||||
|
|
||||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||||
if integrations.IsValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
if m.isValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||||
userAuth.AccountId = impersonate[0]
|
userAuth.AccountId = impersonate[0]
|
||||||
userAuth.IsChild = true
|
userAuth.IsChild = true
|
||||||
}
|
}
|
||||||
@@ -203,7 +206,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
|||||||
}
|
}
|
||||||
|
|
||||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||||
if integrations.IsValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
if m.isValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||||
userAuth.AccountId = impersonate[0]
|
userAuth.AccountId = impersonate[0]
|
||||||
userAuth.IsChild = true
|
userAuth.IsChild = true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -211,6 +211,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
|||||||
},
|
},
|
||||||
disabledLimiter,
|
disabledLimiter,
|
||||||
nil,
|
nil,
|
||||||
|
func(_ context.Context, _, _, _ string) bool { return false },
|
||||||
)
|
)
|
||||||
|
|
||||||
handlerToTest := authMiddleware.Handler(nextHandler)
|
handlerToTest := authMiddleware.Handler(nextHandler)
|
||||||
@@ -270,6 +271,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
|
func(_ context.Context, _, _, _ string) bool { return false },
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -322,6 +324,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
|
func(_ context.Context, _, _, _ string) bool { return false },
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -365,6 +368,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
|
func(_ context.Context, _, _, _ string) bool { return false },
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -409,6 +413,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
|
func(_ context.Context, _, _, _ string) bool { return false },
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -473,6 +478,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
|
func(_ context.Context, _, _, _ string) bool { return false },
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -532,6 +538,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
|
func(_ context.Context, _, _, _ string) bool { return false },
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -587,6 +594,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
|
func(_ context.Context, _, _, _ string) bool { return false },
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -687,6 +695,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
|||||||
},
|
},
|
||||||
disabledLimiter,
|
disabledLimiter,
|
||||||
nil,
|
nil,
|
||||||
|
func(_ context.Context, _, _, _ string) bool { return false },
|
||||||
)
|
)
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"go.opentelemetry.io/otel/metric/noop"
|
"go.opentelemetry.io/otel/metric/noop"
|
||||||
@@ -135,7 +136,8 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
|||||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||||
|
|
||||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
|
||||||
|
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create API handler: %v", err)
|
t.Fatalf("Failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
@@ -264,7 +266,8 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
|||||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||||
|
|
||||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
|
||||||
|
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create API handler: %v", err)
|
t.Fatalf("Failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,62 @@
|
|||||||
|
package validator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
cachestore "github.com/eko/gocache/lib/v4/store"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type IntegratedValidatorImpl struct{}
|
||||||
|
|
||||||
|
func NewIntegratedValidator(_ context.Context, _ peers.Manager, _ settings.Manager, _ activity.Store, _ cachestore.StoreInterface) (*IntegratedValidatorImpl, error) {
|
||||||
|
return &IntegratedValidatorImpl{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *IntegratedValidatorImpl) ValidateExtraSettings(context.Context, *types.ExtraSettings, *types.ExtraSettings, string, string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *IntegratedValidatorImpl) ValidatePeer(_ context.Context, update *nbpeer.Peer, _ *nbpeer.Peer, _ string, _ string, _ string, _ []string, _ *types.ExtraSettings) (*nbpeer.Peer, bool, error) {
|
||||||
|
return update, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *IntegratedValidatorImpl) PreparePeer(_ context.Context, _ string, peer *nbpeer.Peer, _ []string, _ *types.ExtraSettings, _ bool) *nbpeer.Peer {
|
||||||
|
return peer.Copy()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *IntegratedValidatorImpl) IsNotValidPeer(_ context.Context, _ string, _ *nbpeer.Peer, _ []string, _ *types.ExtraSettings) (bool, bool, error) {
|
||||||
|
return false, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *IntegratedValidatorImpl) GetValidatedPeers(_ context.Context, _ string, _ []*types.Group, peers []*nbpeer.Peer, _ *types.ExtraSettings) (map[string]struct{}, error) {
|
||||||
|
validatedPeers := make(map[string]struct{})
|
||||||
|
for _, p := range peers {
|
||||||
|
validatedPeers[p.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
return validatedPeers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *IntegratedValidatorImpl) GetInvalidPeers(_ context.Context, _ string, _ *types.ExtraSettings) (map[string]string, error) {
|
||||||
|
return make(map[string]string), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *IntegratedValidatorImpl) PeerDeleted(_ context.Context, _, _ string, _ *types.ExtraSettings) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *IntegratedValidatorImpl) SetPeerInvalidationListener(_ func(accountID string, peerIDs []string)) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *IntegratedValidatorImpl) Stop(_ context.Context) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *IntegratedValidatorImpl) ValidateFlowResponse(_ context.Context, _ string, flowResponse *proto.PKCEAuthorizationFlow) *proto.PKCEAuthorizationFlow {
|
||||||
|
return flowResponse
|
||||||
|
}
|
||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
nbversion "github.com/netbirdio/netbird/version"
|
nbversion "github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
@@ -53,6 +54,7 @@ type DataSource interface {
|
|||||||
GetAllAccounts(ctx context.Context) []*types.Account
|
GetAllAccounts(ctx context.Context) []*types.Account
|
||||||
GetStoreEngine() types.Engine
|
GetStoreEngine() types.Engine
|
||||||
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
|
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
|
||||||
|
GetProxyMetrics(ctx context.Context) (store.ProxyMetrics, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConnManager peer connection manager that holds state for current active connections
|
// ConnManager peer connection manager that holds state for current active connections
|
||||||
@@ -223,6 +225,12 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
|||||||
servicesAuthPassword int
|
servicesAuthPassword int
|
||||||
servicesAuthPin int
|
servicesAuthPin int
|
||||||
servicesAuthOIDC int
|
servicesAuthOIDC int
|
||||||
|
// Private-service signals — track adoption of NetBird-only mode
|
||||||
|
// (services backed by an embedded proxy peer + access groups).
|
||||||
|
servicesPrivate int
|
||||||
|
servicesPrivateWithGroups int
|
||||||
|
servicesPrivateAccessGroupsSum int
|
||||||
|
servicesWithDirectUpstream int
|
||||||
)
|
)
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
metricsProperties := make(properties)
|
metricsProperties := make(properties)
|
||||||
@@ -380,9 +388,31 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
|||||||
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled {
|
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled {
|
||||||
servicesAuthOIDC++
|
servicesAuthOIDC++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if service.Private {
|
||||||
|
servicesPrivate++
|
||||||
|
if len(service.AccessGroups) > 0 {
|
||||||
|
servicesPrivateWithGroups++
|
||||||
|
}
|
||||||
|
servicesPrivateAccessGroupsSum += len(service.AccessGroups)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, target := range service.Targets {
|
||||||
|
if target.Options.DirectUpstream {
|
||||||
|
servicesWithDirectUpstream++
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Proxy / BYOP cluster signals come from the proxies table aggregated
|
||||||
|
// across all accounts in a single store query; nil on FileStore.
|
||||||
|
proxyMetrics, err := w.dataSource.GetProxyMetrics(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("collect proxy metrics: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
minActivePeerVersion, maxActivePeerVersion := getMinMaxVersion(peerActiveVersions)
|
minActivePeerVersion, maxActivePeerVersion := getMinMaxVersion(peerActiveVersions)
|
||||||
metricsProperties["uptime"] = uptime
|
metricsProperties["uptime"] = uptime
|
||||||
metricsProperties["accounts"] = accounts
|
metricsProperties["accounts"] = accounts
|
||||||
@@ -430,6 +460,15 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
|||||||
metricsProperties["services_auth_password"] = servicesAuthPassword
|
metricsProperties["services_auth_password"] = servicesAuthPassword
|
||||||
metricsProperties["services_auth_pin"] = servicesAuthPin
|
metricsProperties["services_auth_pin"] = servicesAuthPin
|
||||||
metricsProperties["services_auth_oidc"] = servicesAuthOIDC
|
metricsProperties["services_auth_oidc"] = servicesAuthOIDC
|
||||||
|
metricsProperties["services_private"] = servicesPrivate
|
||||||
|
metricsProperties["services_private_with_access_groups"] = servicesPrivateWithGroups
|
||||||
|
metricsProperties["services_private_access_groups_sum"] = servicesPrivateAccessGroupsSum
|
||||||
|
metricsProperties["services_with_direct_upstream"] = servicesWithDirectUpstream
|
||||||
|
metricsProperties["proxy_clusters"] = proxyMetrics.Clusters
|
||||||
|
metricsProperties["proxy_clusters_byop"] = proxyMetrics.ClustersBYOP
|
||||||
|
metricsProperties["proxy_clusters_private"] = proxyMetrics.ClustersPrivate
|
||||||
|
metricsProperties["proxies"] = proxyMetrics.Proxies
|
||||||
|
metricsProperties["proxies_connected"] = proxyMetrics.ProxiesConnected
|
||||||
metricsProperties["custom_domains"] = customDomains
|
metricsProperties["custom_domains"] = customDomains
|
||||||
metricsProperties["custom_domains_validated"] = customDomainsValidated
|
metricsProperties["custom_domains_validated"] = customDomainsValidated
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
@@ -123,7 +124,7 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
Targets: []*rpservice.Target{
|
Targets: []*rpservice.Target{
|
||||||
{TargetType: "peer"},
|
{TargetType: "peer"},
|
||||||
{TargetType: "host"},
|
{TargetType: "host", Options: rpservice.TargetOptions{DirectUpstream: true}},
|
||||||
},
|
},
|
||||||
Auth: rpservice.AuthConfig{
|
Auth: rpservice.AuthConfig{
|
||||||
PasswordAuth: &rpservice.PasswordAuthConfig{Enabled: true},
|
PasswordAuth: &rpservice.PasswordAuthConfig{Enabled: true},
|
||||||
@@ -141,6 +142,16 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
|||||||
},
|
},
|
||||||
Meta: rpservice.Meta{Status: string(rpservice.StatusPending)},
|
Meta: rpservice.Meta{Status: string(rpservice.StatusPending)},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "svc3-private",
|
||||||
|
Enabled: true,
|
||||||
|
Private: true,
|
||||||
|
AccessGroups: []string{"grp-eng", "grp-ops"},
|
||||||
|
Targets: []*rpservice.Target{
|
||||||
|
{TargetType: "cluster", Options: rpservice.TargetOptions{DirectUpstream: true}},
|
||||||
|
},
|
||||||
|
Meta: rpservice.Meta{Status: string(rpservice.StatusActive)},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -254,6 +265,18 @@ func (mockDatasource) GetCustomDomainsCounts(_ context.Context) (int64, int64, e
|
|||||||
return 3, 2, nil
|
return 3, 2, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetProxyMetrics returns canned proxy/cluster counts so the
|
||||||
|
// generateProperties test can assert the BYOP signals end-to-end.
|
||||||
|
func (mockDatasource) GetProxyMetrics(_ context.Context) (store.ProxyMetrics, error) {
|
||||||
|
return store.ProxyMetrics{
|
||||||
|
Clusters: 3,
|
||||||
|
ClustersBYOP: 1,
|
||||||
|
ClustersPrivate: 1,
|
||||||
|
Proxies: 4,
|
||||||
|
ProxiesConnected: 2,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties
|
// TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties
|
||||||
func TestGenerateProperties(t *testing.T) {
|
func TestGenerateProperties(t *testing.T) {
|
||||||
ds := mockDatasource{}
|
ds := mockDatasource{}
|
||||||
@@ -393,17 +416,17 @@ func TestGenerateProperties(t *testing.T) {
|
|||||||
t.Errorf("expected 3 embedded_idp_count, got %v", properties["embedded_idp_count"])
|
t.Errorf("expected 3 embedded_idp_count, got %v", properties["embedded_idp_count"])
|
||||||
}
|
}
|
||||||
|
|
||||||
if properties["services"] != 2 {
|
if properties["services"] != 3 {
|
||||||
t.Errorf("expected 2 services, got %v", properties["services"])
|
t.Errorf("expected 3 services, got %v", properties["services"])
|
||||||
}
|
}
|
||||||
if properties["services_enabled"] != 1 {
|
if properties["services_enabled"] != 2 {
|
||||||
t.Errorf("expected 1 services_enabled, got %v", properties["services_enabled"])
|
t.Errorf("expected 2 services_enabled, got %v", properties["services_enabled"])
|
||||||
}
|
}
|
||||||
if properties["services_targets"] != 3 {
|
if properties["services_targets"] != 4 {
|
||||||
t.Errorf("expected 3 services_targets, got %v", properties["services_targets"])
|
t.Errorf("expected 4 services_targets, got %v", properties["services_targets"])
|
||||||
}
|
}
|
||||||
if properties["services_status_active"] != 1 {
|
if properties["services_status_active"] != 2 {
|
||||||
t.Errorf("expected 1 services_status_active, got %v", properties["services_status_active"])
|
t.Errorf("expected 2 services_status_active, got %v", properties["services_status_active"])
|
||||||
}
|
}
|
||||||
if properties["services_status_pending"] != 1 {
|
if properties["services_status_pending"] != 1 {
|
||||||
t.Errorf("expected 1 services_status_pending, got %v", properties["services_status_pending"])
|
t.Errorf("expected 1 services_status_pending, got %v", properties["services_status_pending"])
|
||||||
@@ -420,6 +443,9 @@ func TestGenerateProperties(t *testing.T) {
|
|||||||
if properties["services_target_type_domain"] != 1 {
|
if properties["services_target_type_domain"] != 1 {
|
||||||
t.Errorf("expected 1 services_target_type_domain, got %v", properties["services_target_type_domain"])
|
t.Errorf("expected 1 services_target_type_domain, got %v", properties["services_target_type_domain"])
|
||||||
}
|
}
|
||||||
|
if properties["services_target_type_cluster"] != 1 {
|
||||||
|
t.Errorf("expected 1 services_target_type_cluster, got %v", properties["services_target_type_cluster"])
|
||||||
|
}
|
||||||
if properties["services_auth_password"] != 1 {
|
if properties["services_auth_password"] != 1 {
|
||||||
t.Errorf("expected 1 services_auth_password, got %v", properties["services_auth_password"])
|
t.Errorf("expected 1 services_auth_password, got %v", properties["services_auth_password"])
|
||||||
}
|
}
|
||||||
@@ -429,6 +455,33 @@ func TestGenerateProperties(t *testing.T) {
|
|||||||
if properties["services_auth_pin"] != 0 {
|
if properties["services_auth_pin"] != 0 {
|
||||||
t.Errorf("expected 0 services_auth_pin, got %v", properties["services_auth_pin"])
|
t.Errorf("expected 0 services_auth_pin, got %v", properties["services_auth_pin"])
|
||||||
}
|
}
|
||||||
|
if properties["services_private"] != 1 {
|
||||||
|
t.Errorf("expected 1 services_private, got %v", properties["services_private"])
|
||||||
|
}
|
||||||
|
if properties["services_private_with_access_groups"] != 1 {
|
||||||
|
t.Errorf("expected 1 services_private_with_access_groups, got %v", properties["services_private_with_access_groups"])
|
||||||
|
}
|
||||||
|
if properties["services_private_access_groups_sum"] != 2 {
|
||||||
|
t.Errorf("expected 2 services_private_access_groups_sum, got %v", properties["services_private_access_groups_sum"])
|
||||||
|
}
|
||||||
|
if properties["services_with_direct_upstream"] != 2 {
|
||||||
|
t.Errorf("expected 2 services_with_direct_upstream, got %v", properties["services_with_direct_upstream"])
|
||||||
|
}
|
||||||
|
if properties["proxy_clusters"] != int64(3) {
|
||||||
|
t.Errorf("expected 3 proxy_clusters, got %v", properties["proxy_clusters"])
|
||||||
|
}
|
||||||
|
if properties["proxy_clusters_byop"] != int64(1) {
|
||||||
|
t.Errorf("expected 1 proxy_clusters_byop, got %v", properties["proxy_clusters_byop"])
|
||||||
|
}
|
||||||
|
if properties["proxy_clusters_private"] != int64(1) {
|
||||||
|
t.Errorf("expected 1 proxy_clusters_private, got %v", properties["proxy_clusters_private"])
|
||||||
|
}
|
||||||
|
if properties["proxies"] != int64(4) {
|
||||||
|
t.Errorf("expected 4 proxies, got %v", properties["proxies"])
|
||||||
|
}
|
||||||
|
if properties["proxies_connected"] != int64(2) {
|
||||||
|
t.Errorf("expected 2 proxies_connected, got %v", properties["proxies_connected"])
|
||||||
|
}
|
||||||
if properties["custom_domains"] != int64(3) {
|
if properties["custom_domains"] != int64(3) {
|
||||||
t.Errorf("expected 3 custom_domains, got %v", properties["custom_domains"])
|
t.Errorf("expected 3 custom_domains, got %v", properties["custom_domains"])
|
||||||
}
|
}
|
||||||
|
|||||||
156
management/server/migration/account_seq.go
Normal file
156
management/server/migration/account_seq.go
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
package migration
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BackfillAccountSeqIDs assigns a deterministic per-account sequential id to all
|
||||||
|
// rows of `model` whose account_seq_id is zero, then seeds account_seq_counters
|
||||||
|
// with the next free id per account. Idempotent: safe to re-run; both steps
|
||||||
|
// no-op once everything is consistent.
|
||||||
|
//
|
||||||
|
// Implemented as two table-wide SQL statements with window functions, one
|
||||||
|
// transaction. Backfilling 246k rows across 154k accounts on Postgres takes
|
||||||
|
// well under a second instead of the per-account-loop ~2 minutes.
|
||||||
|
//
|
||||||
|
// orderColumn is the column to use when assigning the deterministic ordering
|
||||||
|
// (typically the primary-key string id).
|
||||||
|
func BackfillAccountSeqIDs[T any](
|
||||||
|
ctx context.Context,
|
||||||
|
db *gorm.DB,
|
||||||
|
entity types.AccountSeqEntity,
|
||||||
|
orderColumn string,
|
||||||
|
) error {
|
||||||
|
var model T
|
||||||
|
if !db.Migrator().HasTable(&model) {
|
||||||
|
log.WithContext(ctx).Debugf("backfill seq id: table for %T missing, skip", model)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt := &gorm.Statement{DB: db}
|
||||||
|
if err := stmt.Parse(&model); err != nil {
|
||||||
|
return fmt.Errorf("parse model: %w", err)
|
||||||
|
}
|
||||||
|
table := quoteIdent(db, stmt.Schema.Table)
|
||||||
|
orderCol := quoteIdent(db, orderColumn)
|
||||||
|
|
||||||
|
return db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
var pending int64
|
||||||
|
if err := tx.Raw(
|
||||||
|
fmt.Sprintf("SELECT count(*) FROM %s WHERE account_seq_id IS NULL OR account_seq_id = 0", table),
|
||||||
|
).Scan(&pending).Error; err != nil {
|
||||||
|
return fmt.Errorf("count pending on %s: %w", table, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pending > 0 {
|
||||||
|
log.WithContext(ctx).Infof("backfill seq id: %s — %d rows pending", table, pending)
|
||||||
|
if err := backfillRankSQL(tx, table, orderCol); err != nil {
|
||||||
|
return fmt.Errorf("rank %s: %w", table, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := seedCountersSQL(tx, table, entity); err != nil {
|
||||||
|
return fmt.Errorf("seed counters for %s: %w", entity, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func quoteIdent(db *gorm.DB, name string) string {
|
||||||
|
switch db.Dialector.Name() {
|
||||||
|
case "mysql":
|
||||||
|
return "`" + name + "`"
|
||||||
|
case "postgres":
|
||||||
|
return `"` + name + `"`
|
||||||
|
default:
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func backfillRankSQL(db *gorm.DB, table, orderCol string) error {
|
||||||
|
dialect := db.Dialector.Name()
|
||||||
|
var sql string
|
||||||
|
switch dialect {
|
||||||
|
case "postgres", "sqlite":
|
||||||
|
sql = fmt.Sprintf(`
|
||||||
|
WITH max_seq AS (
|
||||||
|
SELECT account_id, COALESCE(MAX(account_seq_id), 0) AS max_seq
|
||||||
|
FROM %s
|
||||||
|
GROUP BY account_id
|
||||||
|
),
|
||||||
|
ranked AS (
|
||||||
|
SELECT p.id,
|
||||||
|
m.max_seq + ROW_NUMBER() OVER (PARTITION BY p.account_id ORDER BY p.%s) AS new_seq
|
||||||
|
FROM %s p
|
||||||
|
JOIN max_seq m ON p.account_id = m.account_id
|
||||||
|
WHERE p.account_seq_id IS NULL OR p.account_seq_id = 0
|
||||||
|
)
|
||||||
|
UPDATE %s SET account_seq_id = ranked.new_seq
|
||||||
|
FROM ranked
|
||||||
|
WHERE %s.id = ranked.id
|
||||||
|
`, table, orderCol, table, table, table)
|
||||||
|
case "mysql":
|
||||||
|
sql = fmt.Sprintf(`
|
||||||
|
UPDATE %s p
|
||||||
|
JOIN (
|
||||||
|
SELECT account_id, COALESCE(MAX(account_seq_id), 0) AS max_seq
|
||||||
|
FROM %s
|
||||||
|
GROUP BY account_id
|
||||||
|
) m ON p.account_id = m.account_id
|
||||||
|
JOIN (
|
||||||
|
SELECT id, ROW_NUMBER() OVER (PARTITION BY account_id ORDER BY %s) AS rn
|
||||||
|
FROM %s
|
||||||
|
WHERE account_seq_id IS NULL OR account_seq_id = 0
|
||||||
|
) r ON p.id = r.id
|
||||||
|
SET p.account_seq_id = m.max_seq + r.rn
|
||||||
|
`, table, table, orderCol, table)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported dialect: %s", dialect)
|
||||||
|
}
|
||||||
|
return db.Exec(sql).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func seedCountersSQL(db *gorm.DB, table string, entity types.AccountSeqEntity) error {
|
||||||
|
dialect := db.Dialector.Name()
|
||||||
|
var sql string
|
||||||
|
switch dialect {
|
||||||
|
case "postgres":
|
||||||
|
sql = fmt.Sprintf(`
|
||||||
|
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||||
|
SELECT account_id, ?, MAX(account_seq_id) + 1
|
||||||
|
FROM %s
|
||||||
|
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
|
||||||
|
GROUP BY account_id
|
||||||
|
ON CONFLICT (account_id, entity) DO UPDATE
|
||||||
|
SET next_id = GREATEST(account_seq_counters.next_id, EXCLUDED.next_id)
|
||||||
|
`, table)
|
||||||
|
case "sqlite":
|
||||||
|
sql = fmt.Sprintf(`
|
||||||
|
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||||
|
SELECT account_id, ?, MAX(account_seq_id) + 1
|
||||||
|
FROM %s
|
||||||
|
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
|
||||||
|
GROUP BY account_id
|
||||||
|
ON CONFLICT (account_id, entity) DO UPDATE
|
||||||
|
SET next_id = max(account_seq_counters.next_id, excluded.next_id)
|
||||||
|
`, table)
|
||||||
|
case "mysql":
|
||||||
|
sql = fmt.Sprintf(`
|
||||||
|
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||||
|
SELECT account_id, ?, MAX(account_seq_id) + 1
|
||||||
|
FROM %s
|
||||||
|
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
|
||||||
|
GROUP BY account_id
|
||||||
|
ON DUPLICATE KEY UPDATE next_id = GREATEST(next_id, VALUES(next_id))
|
||||||
|
`, table)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported dialect: %s", dialect)
|
||||||
|
}
|
||||||
|
return db.Exec(sql, string(entity)).Error
|
||||||
|
}
|
||||||
@@ -98,6 +98,7 @@ type MockAccountManager struct {
|
|||||||
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||||
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||||
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||||
|
ExtendPeerSessionFunc func(ctx context.Context, peerPubKey, userID string) (time.Time, error)
|
||||||
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||||
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
|
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
|
||||||
ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error)
|
ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error)
|
||||||
@@ -860,6 +861,14 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login types.PeerLog
|
|||||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented")
|
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExtendPeerSession mocks ExtendPeerSession of the AccountManager interface
|
||||||
|
func (am *MockAccountManager) ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) {
|
||||||
|
if am.ExtendPeerSessionFunc != nil {
|
||||||
|
return am.ExtendPeerSessionFunc(ctx, peerPubKey, userID)
|
||||||
|
}
|
||||||
|
return time.Time{}, status.Errorf(codes.Unimplemented, "method ExtendPeerSession is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
// SyncPeer mocks SyncPeer of the AccountManager interface
|
// SyncPeer mocks SyncPeer of the AccountManager interface
|
||||||
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||||
if am.SyncPeerFunc != nil {
|
if am.SyncPeerFunc != nil {
|
||||||
|
|||||||
@@ -69,6 +69,12 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityNameserverGroup)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
newNSGroup.AccountSeqID = seq
|
||||||
|
|
||||||
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
|
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -120,6 +126,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
nsGroupToSave.AccountSeqID = oldNSGroup.AccountSeqID
|
||||||
|
|
||||||
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
|
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -71,9 +71,20 @@ func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network
|
|||||||
|
|
||||||
network.ID = xid.New().String()
|
network.ID = xid.New().String()
|
||||||
|
|
||||||
err = m.store.SaveNetwork(ctx, network)
|
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
seq, err := transaction.AllocateAccountSeqID(ctx, network.AccountID, serverTypes.AccountSeqEntityNetwork)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to allocate network seq id: %w", err)
|
||||||
|
}
|
||||||
|
network.AccountSeqID = seq
|
||||||
|
|
||||||
|
if err := transaction.SaveNetwork(ctx, network); err != nil {
|
||||||
|
return fmt.Errorf("failed to save network: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to save network: %w", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkCreated, network.EventMeta())
|
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkCreated, network.EventMeta())
|
||||||
@@ -102,14 +113,25 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID)
|
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
existing, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get network: %w", err)
|
||||||
|
}
|
||||||
|
network.AccountSeqID = existing.AccountSeqID
|
||||||
|
|
||||||
|
if err := transaction.SaveNetwork(ctx, network); err != nil {
|
||||||
|
return fmt.Errorf("failed to save network: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get network: %w", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta())
|
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta())
|
||||||
|
|
||||||
return network, m.store.SaveNetwork(ctx, network)
|
return network, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {
|
func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user