Compare commits

...

21 Commits

Author SHA1 Message Date
Maycon Santos
67211010f7 [client, gui] fix exit nodes menu on reconnect, remove tooltips (#5167)
* [client, gui] fix exit nodes menu on reconnect

clean s.exitNodeStates when disconnecting

* disable tooltip for exit nodes and settings
2026-01-23 18:39:45 +01:00
Maycon Santos
c61568ceb4 [client] Change default rosenpass log level (#5137)
* Change default rosenpass log level

- Add support to environment configuration
- Change default log level to info

* use .String() for print log level
2026-01-23 18:06:54 +01:00
Vlad
737d6061bf [management] ephemeral peers track on login (#5165) 2026-01-23 18:05:22 +01:00
Zoltan Papp
ee3a67d2d8 [client] Fix/health result in bundle (#5164)
* Add support for optional status refresh callback during debug bundle generation

* Always update wg status

* Remove duplicated wg status call
2026-01-23 17:06:07 +01:00
Viktor Liu
1a32e4c223 [client] Fix IPv4-only in bind proxy (#5154) 2026-01-23 15:15:34 +01:00
Viktor Liu
269d5d1cba [client] Try next DNS upstream on SERVFAIL/REFUSED responses (#5163) 2026-01-23 11:59:52 +01:00
Bethuel Mmbaga
a1de2b8a98 [management] Move activity store encryption to shared crypt package (#5111) 2026-01-22 15:01:13 +03:00
Viktor Liu
d0221a3e72 [client] Add cpu profile to debug bundle (#4700) 2026-01-22 12:24:12 +01:00
Bethuel Mmbaga
8da23daae3 [management] Fix activity event initiator for user group changes (#5152) 2026-01-22 14:18:46 +03:00
Viktor Liu
f86022eace [client] Hide forwarding rules in status when count is zero (#5149)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 10:01:08 +01:00
Viktor Liu
ee54827f94 [client] Add IPv6 support to usersace bind (#5147) 2026-01-22 10:20:43 +08:00
Zoltan Papp
e908dea702 [client] Extend WG watcher for ICE connection too (#5133)
Extend WG watcher for ICE connection too
2026-01-21 10:42:13 +01:00
Maycon Santos
030650a905 [client] Fix RFC 4592 wildcard matching for existing domain names (#5145)
Per RFC 4592 section 2.2.1, wildcards should only match when the queried
name does not exist in the zone. Previously, if host.example.com had an
A record and *.example.com had an AAAA record, querying AAAA for
host.example.com would incorrectly return the wildcard AAAA instead of
NODATA.

Now the resolver checks if the domain exists (with any record type)
before falling back to wildcard matching, returning proper NODATA
responses for existing names without the requested record type.
2026-01-21 08:48:32 +01:00
Misha Bragin
e01998815e [infra] add embedded STUN to getting started (#5141) 2026-01-20 19:01:34 +01:00
Zoltan Papp
07e4a5a23c Fixes profile switching and repeated down/up command failures. (#5142)
When Down() and Up() are called in quick succession, the connectWithRetryRuns goroutine could set ErrResetConnection after Down() had cleared the state, causing the subsequent Up() to fail.

Fix by waiting for the goroutine to exit (via clientGiveUpChan) before Down() returns. Uses a 5-second timeout to prevent RPC timeouts while ensuring the goroutine completes in most cases.
2026-01-20 18:22:37 +01:00
Diego Romar
b3a2992a10 [client/android] - Fix Rosenpass connectivity for Android peers (#5044)
* [client] Add WGConfigurer interface

To allow Rosenpass to work both with kernel
WireGuard via wgctrl (default behavior) and
userspace WireGuard via IPC on Android/iOS
using WGUSPConfigurer

* [client] Remove Rosenpass debug logs

* [client] Return simpler peer configuration in outputKey method

ConfigureDevice, the method previously used in
outputKey via wgClient to update the device's
properties, is now defined in the WGConfigurer
interface and implemented both in kernel_unix and
usp configurers.

PresharedKey datatype was also changed from
boolean to [32]byte to compare it
to the original NetBird PSK, so that Rosenpass
may replace it with its own when necessary.

* [client] Remove unused field

* [client] Replace usage of WGConfigurer

Replaced with preshared key setter interface,
which only defines a method to set / update the preshared key.

Logic has been migrated from rosenpass/netbird_handler to client/iface.

* [client] Use same default peer keepalive value when setting preshared keys

* [client] Store PresharedKeySetter iface in rosenpass manager

To avoid no-op if SetInterface is called before generateConfig

* [client] Add mutex usage in rosenpass netbird handler

* [client] change implementation setting Rosenpass preshared key

Instead of providing a method to configure a device (device/interface.go),
it forwards the new parameters to the configurer (either
kernel_unix.go / usp.go).

This removes dependency on reading FullStats, and makes use of a common
method (buildPresharedKeyConfig in configurer/common.go) to build a
minimal WG config that only sets/updates the PSK.

netbird_handler.go now keeps s list of initializedPeers to choose whether
to set the value of "UpdateOnly" when calling iface.SetPresharedKey.

* [client] Address possible race condition

Between outputKey calls and peer removal; it
checks again if the peer still exists in the
peers map before inserting it in the
initializedPeers map.

* [client] Add psk Rosenpass-initialized check

On client/internal/peer/conn.go, the presharedKey
function would always return the current key
set in wgConfig.presharedKey.

This would eventually overwrite a key set
by Rosenpass if the feature is active.

The purpose here is to set a handler that will
check if a given peer has its psk initialized
by Rosenpass to skip updating the psk
via updatePeer (since it calls presharedKey
method in conn.go).

* Add missing updateOnly flag setup for usp peers

* Change common.go buildPresharedKeyConfig signature

PeerKey datatype changed from string to
wgTypes.Key. Callers are responsible for parsing
a peer key with string datatype.
2026-01-20 13:26:51 -03:00
Maycon Santos
202fa47f2b [client] Add support to wildcard custom records (#5125)
* **New Features**
  * Wildcard DNS fallback for eligible query types (excluding NS/SOA): attempts wildcard records when no exact match, rewrites wildcard names back to the original query, and rotates responses; preserves CNAME resolution.

* **Tests**
  * Vastly expanded coverage for wildcard behaviors, precedence, multi-record round‑robin, multi-type chains, multi-hop and cross-zone scenarios, and edge cases (NXDOMAIN/NODATA, fallthrough).

* **Chores**
  * CI lint config updated to ignore an additional codespell entry.
2026-01-20 17:21:25 +01:00
Misha Bragin
4888021ba6 Add missing activity events to the API response (#5140) 2026-01-20 15:12:22 +01:00
Misha Bragin
a0b0b664b6 Local user password change (embedded IdP) (#5132) 2026-01-20 14:16:42 +01:00
Diego Romar
50da5074e7 [client] change notifyDisconnected call (#5138)
On handleJobStream, when handling error codes 
from receiveJobRequest in the switch-case, 
notifying disconnected in cases where it isn't a 
disconnection breaks connection status reporting 
on mobile peers.

This commit changes it so it isn't called on
Canceled or Unimplemented status codes.
2026-01-20 07:14:33 -03:00
Zoltan Papp
58daa674ef [Management/Client] Trigger debug bundle runs from API/Dashboard (#4592) (#4832)
This PR adds the ability to trigger debug bundle generation remotely from the Management API/Dashboard.
2026-01-19 11:22:16 +01:00
116 changed files with 8488 additions and 2553 deletions

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
skip: go.mod,go.sum
golangci:
strategy:

View File

@@ -16,7 +16,6 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/upload-server/types"
)
@@ -98,7 +97,6 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn)
request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
}
@@ -221,21 +219,37 @@ func runForDuration(cmd *cobra.Command, args []string) error {
time.Sleep(3 * time.Second)
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
cpuProfilingStarted := false
if _, err := client.StartCPUProfile(cmd.Context(), &proto.StartCPUProfileRequest{}); err != nil {
cmd.PrintErrf("Failed to start CPU profiling: %v\n", err)
} else {
cpuProfilingStarted = true
defer func() {
if cpuProfilingStarted {
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
}
}
}()
}
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr
}
cmd.Println("\nDuration completed")
if cpuProfilingStarted {
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
} else {
cpuProfilingStarted = false
}
}
cmd.Println("Creating debug bundle...")
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
}
@@ -302,24 +316,6 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
return nil
}
func getStatusOutput(cmd *cobra.Command, anon bool) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context(), true)
if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName)
statusOutputString = overview.FullDetailSummary()
}
return statusOutputString
}
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
@@ -378,7 +374,8 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
InternalConfig: config,
StatusRecorder: recorder,
SyncResponse: syncResponse,
LogFile: logFilePath,
LogPath: logFilePath,
CPUProfile: nil,
},
debug.BundleConfig{
IncludeSystemInfo: true,

View File

@@ -99,7 +99,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
profName = activeProf.Name
}
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
var statusOutputString string
switch {
case detailFlag:

View File

@@ -18,6 +18,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/job"
clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server"
@@ -97,6 +98,8 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
peersmanager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersmanager)
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
@@ -115,7 +118,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
@@ -124,7 +127,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
if err != nil {
t.Fatal(err)
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -200,7 +200,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
connectClient := internal.NewConnectClient(ctx, config, r, false)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil)
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
}
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {

View File

@@ -190,7 +190,7 @@ func (c *Client) Start(startCtx context.Context) error {
run := make(chan struct{})
clientErr := make(chan error, 1)
go func() {
if err := client.Run(run); err != nil {
if err := client.Run(run, ""); err != nil {
clientErr <- err
}
}()

View File

@@ -0,0 +1,169 @@
package bind
import (
"errors"
"net"
"sync"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
)
var (
errNoIPv4Conn = errors.New("no IPv4 connection available")
errNoIPv6Conn = errors.New("no IPv6 connection available")
errInvalidAddr = errors.New("invalid address type")
)
// DualStackPacketConn wraps IPv4 and IPv6 UDP connections and routes writes
// to the appropriate connection based on the destination address.
// ReadFrom is not used in the hot path - ICEBind receives packets via
// BatchReader.ReadBatch() directly. This is only used by udpMux for sending.
type DualStackPacketConn struct {
ipv4Conn net.PacketConn
ipv6Conn net.PacketConn
readFromWarn sync.Once
}
// NewDualStackPacketConn creates a new dual-stack packet connection.
func NewDualStackPacketConn(ipv4Conn, ipv6Conn net.PacketConn) *DualStackPacketConn {
return &DualStackPacketConn{
ipv4Conn: ipv4Conn,
ipv6Conn: ipv6Conn,
}
}
// ReadFrom reads from the available connection (preferring IPv4).
// NOTE: This method is NOT used in the data path. ICEBind receives packets via
// BatchReader.ReadBatch() directly for both IPv4 and IPv6, which is much more efficient.
// This implementation exists only to satisfy the net.PacketConn interface for the udpMux,
// but the udpMux only uses WriteTo() for sending STUN responses - it never calls ReadFrom()
// because STUN packets are filtered and forwarded via HandleSTUNMessage() from the receive path.
func (d *DualStackPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
d.readFromWarn.Do(func() {
log.Warn("DualStackPacketConn.ReadFrom called - this is unexpected and may indicate an inefficient code path")
})
if d.ipv4Conn != nil {
return d.ipv4Conn.ReadFrom(b)
}
if d.ipv6Conn != nil {
return d.ipv6Conn.ReadFrom(b)
}
return 0, nil, net.ErrClosed
}
// WriteTo writes to the appropriate connection based on the address type.
func (d *DualStackPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
return 0, &net.OpError{
Op: "write",
Net: "udp",
Addr: addr,
Err: errInvalidAddr,
}
}
if udpAddr.IP.To4() == nil {
if d.ipv6Conn != nil {
return d.ipv6Conn.WriteTo(b, addr)
}
return 0, &net.OpError{
Op: "write",
Net: "udp6",
Addr: addr,
Err: errNoIPv6Conn,
}
}
if d.ipv4Conn != nil {
return d.ipv4Conn.WriteTo(b, addr)
}
return 0, &net.OpError{
Op: "write",
Net: "udp4",
Addr: addr,
Err: errNoIPv4Conn,
}
}
// Close closes both connections.
func (d *DualStackPacketConn) Close() error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.Close(); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.Close(); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// LocalAddr returns the local address of the IPv4 connection if available,
// otherwise the IPv6 connection.
func (d *DualStackPacketConn) LocalAddr() net.Addr {
if d.ipv4Conn != nil {
return d.ipv4Conn.LocalAddr()
}
if d.ipv6Conn != nil {
return d.ipv6Conn.LocalAddr()
}
return nil
}
// SetDeadline sets the deadline for both connections.
func (d *DualStackPacketConn) SetDeadline(t time.Time) error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.SetDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.SetDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// SetReadDeadline sets the read deadline for both connections.
func (d *DualStackPacketConn) SetReadDeadline(t time.Time) error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.SetReadDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.SetReadDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// SetWriteDeadline sets the write deadline for both connections.
func (d *DualStackPacketConn) SetWriteDeadline(t time.Time) error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.SetWriteDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.SetWriteDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}

View File

@@ -0,0 +1,119 @@
package bind
import (
"net"
"testing"
)
var (
ipv4Addr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345}
ipv6Addr = &net.UDPAddr{IP: net.ParseIP("::1"), Port: 12345}
payload = make([]byte, 1200)
)
func BenchmarkWriteTo_DirectUDPConn(b *testing.B) {
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = conn.WriteTo(payload, ipv4Addr)
}
}
func BenchmarkWriteTo_DualStack_IPv4Only(b *testing.B) {
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn.Close()
ds := NewDualStackPacketConn(conn, nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, ipv4Addr)
}
}
func BenchmarkWriteTo_DualStack_IPv6Only(b *testing.B) {
conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
b.Skipf("IPv6 not available: %v", err)
}
defer conn.Close()
ds := NewDualStackPacketConn(nil, conn)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, ipv6Addr)
}
}
func BenchmarkWriteTo_DualStack_Both_IPv4Traffic(b *testing.B) {
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn4.Close()
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
b.Skipf("IPv6 not available: %v", err)
}
defer conn6.Close()
ds := NewDualStackPacketConn(conn4, conn6)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, ipv4Addr)
}
}
func BenchmarkWriteTo_DualStack_Both_IPv6Traffic(b *testing.B) {
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn4.Close()
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
b.Skipf("IPv6 not available: %v", err)
}
defer conn6.Close()
ds := NewDualStackPacketConn(conn4, conn6)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, ipv6Addr)
}
}
func BenchmarkWriteTo_DualStack_Both_MixedTraffic(b *testing.B) {
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn4.Close()
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
b.Skipf("IPv6 not available: %v", err)
}
defer conn6.Close()
ds := NewDualStackPacketConn(conn4, conn6)
addrs := []net.Addr{ipv4Addr, ipv6Addr}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, addrs[i&1])
}
}

View File

@@ -0,0 +1,191 @@
package bind
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDualStackPacketConn_RoutesWritesToCorrectSocket(t *testing.T) {
ipv4Conn := &mockPacketConn{network: "udp4"}
ipv6Conn := &mockPacketConn{network: "udp6"}
dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn)
tests := []struct {
name string
addr *net.UDPAddr
wantSocket string
}{
{
name: "IPv4 address",
addr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234},
wantSocket: "udp4",
},
{
name: "IPv6 address",
addr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234},
wantSocket: "udp6",
},
{
name: "IPv4-mapped IPv6 goes to IPv4",
addr: &net.UDPAddr{IP: net.ParseIP("::ffff:192.168.1.1"), Port: 1234},
wantSocket: "udp4",
},
{
name: "IPv4 loopback",
addr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234},
wantSocket: "udp4",
},
{
name: "IPv6 loopback",
addr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 1234},
wantSocket: "udp6",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ipv4Conn.writeCount = 0
ipv6Conn.writeCount = 0
n, err := dualStack.WriteTo([]byte("test"), tt.addr)
require.NoError(t, err)
assert.Equal(t, 4, n)
if tt.wantSocket == "udp4" {
assert.Equal(t, 1, ipv4Conn.writeCount, "expected write to IPv4")
assert.Equal(t, 0, ipv6Conn.writeCount, "expected no write to IPv6")
} else {
assert.Equal(t, 0, ipv4Conn.writeCount, "expected no write to IPv4")
assert.Equal(t, 1, ipv6Conn.writeCount, "expected write to IPv6")
}
})
}
}
func TestDualStackPacketConn_IPv4OnlyRejectsIPv6(t *testing.T) {
dualStack := NewDualStackPacketConn(&mockPacketConn{network: "udp4"}, nil)
// IPv4 works
_, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234})
require.NoError(t, err)
// IPv6 fails
_, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234})
require.Error(t, err)
assert.Contains(t, err.Error(), "no IPv6 connection")
}
func TestDualStackPacketConn_IPv6OnlyRejectsIPv4(t *testing.T) {
dualStack := NewDualStackPacketConn(nil, &mockPacketConn{network: "udp6"})
// IPv6 works
_, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234})
require.NoError(t, err)
// IPv4 fails
_, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234})
require.Error(t, err)
assert.Contains(t, err.Error(), "no IPv4 connection")
}
// TestDualStackPacketConn_ReadFromIsNotUsedInHotPath documents that ReadFrom
// only reads from one socket (IPv4 preferred). This is fine because the actual
// receive path uses wireguard-go's BatchReader directly, not ReadFrom.
func TestDualStackPacketConn_ReadFromIsNotUsedInHotPath(t *testing.T) {
ipv4Conn := &mockPacketConn{
network: "udp4",
readData: []byte("from ipv4"),
readAddr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234},
}
ipv6Conn := &mockPacketConn{
network: "udp6",
readData: []byte("from ipv6"),
readAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234},
}
dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn)
buf := make([]byte, 100)
n, addr, err := dualStack.ReadFrom(buf)
require.NoError(t, err)
// reads from IPv4 (preferred) - this is expected behavior
assert.Equal(t, "from ipv4", string(buf[:n]))
assert.Equal(t, "192.168.1.1", addr.(*net.UDPAddr).IP.String())
}
func TestDualStackPacketConn_LocalAddrPrefersIPv4(t *testing.T) {
ipv4Addr := &net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: 51820}
ipv6Addr := &net.UDPAddr{IP: net.ParseIP("::"), Port: 51820}
tests := []struct {
name string
ipv4 net.PacketConn
ipv6 net.PacketConn
wantAddr net.Addr
}{
{
name: "both available returns IPv4",
ipv4: &mockPacketConn{localAddr: ipv4Addr},
ipv6: &mockPacketConn{localAddr: ipv6Addr},
wantAddr: ipv4Addr,
},
{
name: "IPv4 only",
ipv4: &mockPacketConn{localAddr: ipv4Addr},
ipv6: nil,
wantAddr: ipv4Addr,
},
{
name: "IPv6 only",
ipv4: nil,
ipv6: &mockPacketConn{localAddr: ipv6Addr},
wantAddr: ipv6Addr,
},
{
name: "neither returns nil",
ipv4: nil,
ipv6: nil,
wantAddr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dualStack := NewDualStackPacketConn(tt.ipv4, tt.ipv6)
assert.Equal(t, tt.wantAddr, dualStack.LocalAddr())
})
}
}
// mock
type mockPacketConn struct {
network string
writeCount int
readData []byte
readAddr net.Addr
localAddr net.Addr
}
func (m *mockPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
if m.readData != nil {
return copy(b, m.readData), m.readAddr, nil
}
return 0, nil, nil
}
func (m *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
m.writeCount++
return len(b), nil
}
func (m *mockPacketConn) Close() error { return nil }
func (m *mockPacketConn) LocalAddr() net.Addr { return m.localAddr }
func (m *mockPacketConn) SetDeadline(t time.Time) error { return nil }
func (m *mockPacketConn) SetReadDeadline(t time.Time) error { return nil }
func (m *mockPacketConn) SetWriteDeadline(t time.Time) error { return nil }

View File

@@ -14,7 +14,6 @@ import (
"github.com/pion/stun/v3"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
@@ -28,22 +27,7 @@ type receiverCreator struct {
}
func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
if ipv4PC, ok := pc.(*ipv4.PacketConn); ok {
return rc.iceBind.createIPv4ReceiverFn(ipv4PC, conn, rxOffload, msgPool)
}
// IPv6 is currently not supported in the udpmux, this is a stub for compatibility with the
// wireguard-go ReceiverCreator interface which is called for both IPv4 and IPv6.
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
buf := bufs[0]
size, ep, err := conn.ReadFromUDPAddrPort(buf)
if err != nil {
return 0, err
}
sizes[0] = size
stdEp := &wgConn.StdNetEndpoint{AddrPort: ep}
eps[0] = stdEp
return 1, nil
}
return rc.iceBind.createReceiverFn(pc, conn, rxOffload, msgPool)
}
// ICEBind is a bind implementation with two main features:
@@ -73,6 +57,8 @@ type ICEBind struct {
muUDPMux sync.Mutex
udpMux *udpmux.UniversalUDPMuxDefault
ipv4Conn *net.UDPConn
ipv6Conn *net.UDPConn
}
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
@@ -118,6 +104,12 @@ func (s *ICEBind) Close() error {
close(s.closedChan)
s.muUDPMux.Lock()
s.ipv4Conn = nil
s.ipv6Conn = nil
s.udpMux = nil
s.muUDPMux.Unlock()
return s.StdNetBind.Close()
}
@@ -175,19 +167,18 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
return nil
}
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
udpmux.UniversalUDPMuxParams{
UDPConn: nbnet.WrapPacketConn(conn),
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
MTU: s.mtu,
},
)
// Detect IPv4 vs IPv6 from connection's local address
if localAddr := conn.LocalAddr().(*net.UDPAddr); localAddr.IP.To4() != nil {
s.ipv4Conn = conn
} else {
s.ipv6Conn = conn
}
s.createOrUpdateMux()
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := getMessages(msgsPool)
for i := range bufs {
@@ -195,12 +186,13 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer putMessages(msgs, msgsPool)
var numMsgs int
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload {
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
//nolint
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
//nolint:staticcheck
_, err = pc.ReadBatch((*msgs)[readAt:], 0)
if err != nil {
return 0, err
}
@@ -222,12 +214,12 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
// todo: handle err
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if ok {
if ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr); ok {
continue
}
sizes[i] = msg.N
@@ -248,6 +240,38 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
}
}
// createOrUpdateMux creates or updates the UDP mux with the available connections.
// Must be called with muUDPMux held.
func (s *ICEBind) createOrUpdateMux() {
var muxConn net.PacketConn
switch {
case s.ipv4Conn != nil && s.ipv6Conn != nil:
muxConn = NewDualStackPacketConn(
nbnet.WrapPacketConn(s.ipv4Conn),
nbnet.WrapPacketConn(s.ipv6Conn),
)
case s.ipv4Conn != nil:
muxConn = nbnet.WrapPacketConn(s.ipv4Conn)
case s.ipv6Conn != nil:
muxConn = nbnet.WrapPacketConn(s.ipv6Conn)
default:
return
}
// Don't close the old mux - it doesn't own the underlying connections.
// The sockets are managed by WireGuard's StdNetBind, not by us.
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
udpmux.UniversalUDPMuxParams{
UDPConn: muxConn,
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
MTU: s.mtu,
},
)
}
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
for i := range buffers {
if !stun.IsMessage(buffers[i]) {
@@ -260,9 +284,14 @@ func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr)
return true, err
}
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
if muxErr != nil {
log.Warnf("failed to handle STUN packet")
s.muUDPMux.Lock()
mux := s.udpMux
s.muUDPMux.Unlock()
if mux != nil {
if muxErr := mux.HandleSTUNMessage(msg, addr); muxErr != nil {
log.Warnf("failed to handle STUN packet: %v", muxErr)
}
}
buffers[i] = []byte{}

View File

@@ -0,0 +1,324 @@
package bind
import (
"fmt"
"net"
"net/netip"
"sync"
"testing"
"time"
"github.com/pion/transport/v3/stdnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func TestICEBind_CreatesReceiverForBothIPv4AndIPv6(t *testing.T) {
iceBind := setupICEBind(t)
ipv4Conn, ipv6Conn := createDualStackConns(t)
defer ipv4Conn.Close()
defer ipv6Conn.Close()
rc := receiverCreator{iceBind}
pool := createMsgPool()
// Simulate wireguard-go calling CreateReceiverFn for IPv4
ipv4RecvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, pool)
require.NotNil(t, ipv4RecvFn)
iceBind.muUDPMux.Lock()
assert.NotNil(t, iceBind.ipv4Conn, "should store IPv4 connection")
assert.Nil(t, iceBind.ipv6Conn, "IPv6 not added yet")
assert.NotNil(t, iceBind.udpMux, "mux should be created after first connection")
iceBind.muUDPMux.Unlock()
// Simulate wireguard-go calling CreateReceiverFn for IPv6
ipv6RecvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, pool)
require.NotNil(t, ipv6RecvFn)
iceBind.muUDPMux.Lock()
assert.NotNil(t, iceBind.ipv4Conn, "should still have IPv4 connection")
assert.NotNil(t, iceBind.ipv6Conn, "should now have IPv6 connection")
assert.NotNil(t, iceBind.udpMux, "mux should still exist")
iceBind.muUDPMux.Unlock()
mux, err := iceBind.GetICEMux()
require.NoError(t, err)
require.NotNil(t, mux)
}
func TestICEBind_WorksWithIPv4Only(t *testing.T) {
iceBind := setupICEBind(t)
ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
require.NoError(t, err)
defer ipv4Conn.Close()
rc := receiverCreator{iceBind}
recvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, createMsgPool())
require.NotNil(t, recvFn)
iceBind.muUDPMux.Lock()
assert.NotNil(t, iceBind.ipv4Conn)
assert.Nil(t, iceBind.ipv6Conn)
assert.NotNil(t, iceBind.udpMux)
iceBind.muUDPMux.Unlock()
mux, err := iceBind.GetICEMux()
require.NoError(t, err)
require.NotNil(t, mux)
}
func TestICEBind_WorksWithIPv6Only(t *testing.T) {
iceBind := setupICEBind(t)
ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
t.Skipf("IPv6 not available: %v", err)
}
defer ipv6Conn.Close()
rc := receiverCreator{iceBind}
recvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, createMsgPool())
require.NotNil(t, recvFn)
iceBind.muUDPMux.Lock()
assert.Nil(t, iceBind.ipv4Conn)
assert.NotNil(t, iceBind.ipv6Conn)
assert.NotNil(t, iceBind.udpMux)
iceBind.muUDPMux.Unlock()
mux, err := iceBind.GetICEMux()
require.NoError(t, err)
require.NotNil(t, mux)
}
// TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously verifies that we can communicate
// with peers on different address families through the same DualStackPacketConn.
func TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously(t *testing.T) {
// two "remote peers" listening on different address families
ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0")
defer ipv4Peer.Close()
ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0})
if err != nil {
t.Skipf("IPv6 not available: %v", err)
}
defer ipv6Peer.Close()
// our local dual-stack connection
ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0")
defer ipv4Local.Close()
ipv6Local := listenUDP(t, "udp6", "[::1]:0")
defer ipv6Local.Close()
dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local)
// send to both peers
_, err = dualStack.WriteTo([]byte("to-ipv4"), ipv4Peer.LocalAddr())
require.NoError(t, err)
_, err = dualStack.WriteTo([]byte("to-ipv6"), ipv6Peer.LocalAddr())
require.NoError(t, err)
// verify IPv4 peer got its packet from the IPv4 socket
buf := make([]byte, 100)
_ = ipv4Peer.SetReadDeadline(time.Now().Add(time.Second))
n, addr, err := ipv4Peer.ReadFrom(buf)
require.NoError(t, err)
assert.Equal(t, "to-ipv4", string(buf[:n]))
assert.Equal(t, ipv4Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port)
// verify IPv6 peer got its packet from the IPv6 socket
_ = ipv6Peer.SetReadDeadline(time.Now().Add(time.Second))
n, addr, err = ipv6Peer.ReadFrom(buf)
require.NoError(t, err)
assert.Equal(t, "to-ipv6", string(buf[:n]))
assert.Equal(t, ipv6Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port)
}
// TestICEBind_HandlesConcurrentMixedTraffic sends packets concurrently to both IPv4
// and IPv6 peers. Verifies no packets get misrouted (IPv4 peer only gets v4- packets,
// IPv6 peer only gets v6- packets). Some packet loss is acceptable for UDP.
func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0")
defer ipv4Peer.Close()
ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0})
if err != nil {
t.Skipf("IPv6 not available: %v", err)
}
defer ipv6Peer.Close()
ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0")
defer ipv4Local.Close()
ipv6Local := listenUDP(t, "udp6", "[::1]:0")
defer ipv6Local.Close()
dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local)
const packetsPerFamily = 500
ipv4Received := make(chan string, packetsPerFamily)
ipv6Received := make(chan string, packetsPerFamily)
startGate := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
buf := make([]byte, 100)
for i := 0; i < packetsPerFamily; i++ {
n, _, err := ipv4Peer.ReadFrom(buf)
if err != nil {
return
}
ipv4Received <- string(buf[:n])
}
}()
wg.Add(1)
go func() {
defer wg.Done()
buf := make([]byte, 100)
for i := 0; i < packetsPerFamily; i++ {
n, _, err := ipv6Peer.ReadFrom(buf)
if err != nil {
return
}
ipv6Received <- string(buf[:n])
}
}()
wg.Add(1)
go func() {
defer wg.Done()
<-startGate
for i := 0; i < packetsPerFamily; i++ {
_, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v4-%04d", i)), ipv4Peer.LocalAddr())
}
}()
wg.Add(1)
go func() {
defer wg.Done()
<-startGate
for i := 0; i < packetsPerFamily; i++ {
_, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v6-%04d", i)), ipv6Peer.LocalAddr())
}
}()
close(startGate)
time.AfterFunc(5*time.Second, func() {
_ = ipv4Peer.SetReadDeadline(time.Now())
_ = ipv6Peer.SetReadDeadline(time.Now())
})
wg.Wait()
close(ipv4Received)
close(ipv6Received)
ipv4Count := 0
for pkt := range ipv4Received {
require.True(t, len(pkt) >= 3 && pkt[:3] == "v4-", "IPv4 peer got misrouted packet: %s", pkt)
ipv4Count++
}
ipv6Count := 0
for pkt := range ipv6Received {
require.True(t, len(pkt) >= 3 && pkt[:3] == "v6-", "IPv6 peer got misrouted packet: %s", pkt)
ipv6Count++
}
assert.Equal(t, packetsPerFamily, ipv4Count)
assert.Equal(t, packetsPerFamily, ipv6Count)
}
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
tests := []struct {
name string
network string
addr string
wantIPv4 bool
}{
{"IPv4 any", "udp4", "0.0.0.0:0", true},
{"IPv4 loopback", "udp4", "127.0.0.1:0", true},
{"IPv6 any", "udp6", "[::]:0", false},
{"IPv6 loopback", "udp6", "[::1]:0", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr, err := net.ResolveUDPAddr(tt.network, tt.addr)
require.NoError(t, err)
conn, err := net.ListenUDP(tt.network, addr)
if err != nil {
t.Skipf("%s not available: %v", tt.network, err)
}
defer conn.Close()
localAddr := conn.LocalAddr().(*net.UDPAddr)
isIPv4 := localAddr.IP.To4() != nil
assert.Equal(t, tt.wantIPv4, isIPv4)
})
}
}
// helpers
func setupICEBind(t *testing.T) *ICEBind {
t.Helper()
transportNet, err := stdnet.NewNet()
require.NoError(t, err)
address := wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/10"),
}
return NewICEBind(transportNet, nil, address, 1280)
}
func createDualStackConns(t *testing.T) (*net.UDPConn, *net.UDPConn) {
t.Helper()
ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
require.NoError(t, err)
ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
ipv4Conn.Close()
t.Skipf("IPv6 not available: %v", err)
}
return ipv4Conn, ipv6Conn
}
func createMsgPool() *sync.Pool {
return &sync.Pool{
New: func() any {
msgs := make([]ipv6.Message, 1)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, 0, 40)
}
return &msgs
},
}
}
func listenUDP(t *testing.T, network, addr string) *net.UDPConn {
t.Helper()
udpAddr, err := net.ResolveUDPAddr(network, addr)
require.NoError(t, err)
conn, err := net.ListenUDP(network, udpAddr)
require.NoError(t, err)
return conn
}

View File

@@ -3,8 +3,22 @@ package configurer
import (
"net"
"net/netip"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// buildPresharedKeyConfig creates a wgtypes.Config for setting a preshared key on a peer.
// This is a shared helper used by both kernel and userspace configurers.
func buildPresharedKeyConfig(peerKey wgtypes.Key, psk wgtypes.Key, updateOnly bool) wgtypes.Config {
return wgtypes.Config{
Peers: []wgtypes.PeerConfig{{
PublicKey: peerKey,
PresharedKey: &psk,
UpdateOnly: updateOnly,
}},
}
}
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
ipNets := make([]net.IPNet, len(prefixes))
for i, prefix := range prefixes {

View File

@@ -15,8 +15,6 @@ import (
"github.com/netbirdio/netbird/monotime"
)
var zeroKey wgtypes.Key
type KernelConfigurer struct {
deviceName string
}
@@ -48,6 +46,18 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
return nil
}
// SetPresharedKey sets the preshared key for a peer.
// If updateOnly is true, only updates the existing peer; if false, creates or updates.
func (c *KernelConfigurer) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
parsedPeerKey, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
cfg := buildPresharedKeyConfig(parsedPeerKey, psk, updateOnly)
return c.configure(cfg)
}
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
@@ -279,7 +289,7 @@ func (c *KernelConfigurer) FullStats() (*Stats, error) {
TxBytes: p.TransmitBytes,
RxBytes: p.ReceiveBytes,
LastHandshake: p.LastHandshakeTime,
PresharedKey: p.PresharedKey != zeroKey,
PresharedKey: [32]byte(p.PresharedKey),
}
if p.Endpoint != nil {
peer.Endpoint = *p.Endpoint

View File

@@ -22,17 +22,16 @@ import (
)
const (
privateKey = "private_key"
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
ipcKeyTxBytes = "tx_bytes"
ipcKeyRxBytes = "rx_bytes"
allowedIP = "allowed_ip"
endpoint = "endpoint"
fwmark = "fwmark"
listenPort = "listen_port"
publicKey = "public_key"
presharedKey = "preshared_key"
privateKey = "private_key"
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
ipcKeyTxBytes = "tx_bytes"
ipcKeyRxBytes = "rx_bytes"
allowedIP = "allowed_ip"
endpoint = "endpoint"
fwmark = "fwmark"
listenPort = "listen_port"
publicKey = "public_key"
presharedKey = "preshared_key"
)
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
@@ -72,6 +71,18 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
return c.device.IpcSet(toWgUserspaceString(config))
}
// SetPresharedKey sets the preshared key for a peer.
// If updateOnly is true, only updates the existing peer; if false, creates or updates.
func (c *WGUSPConfigurer) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
parsedPeerKey, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
cfg := buildPresharedKeyConfig(parsedPeerKey, psk, updateOnly)
return c.device.IpcSet(toWgUserspaceString(cfg))
}
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
@@ -422,23 +433,19 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
hexKey := hex.EncodeToString(p.PublicKey[:])
sb.WriteString(fmt.Sprintf("public_key=%s\n", hexKey))
if p.Remove {
sb.WriteString("remove=true\n")
}
if p.UpdateOnly {
sb.WriteString("update_only=true\n")
}
if p.PresharedKey != nil {
preSharedHexKey := hex.EncodeToString(p.PresharedKey[:])
sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey))
}
if p.Remove {
sb.WriteString("remove=true")
}
if p.ReplaceAllowedIPs {
sb.WriteString("replace_allowed_ips=true\n")
}
for _, aip := range p.AllowedIPs {
sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String()))
}
if p.Endpoint != nil {
sb.WriteString(fmt.Sprintf("endpoint=%s\n", p.Endpoint.String()))
}
@@ -446,6 +453,14 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
if p.PersistentKeepaliveInterval != nil {
sb.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds())))
}
if p.ReplaceAllowedIPs {
sb.WriteString("replace_allowed_ips=true\n")
}
for _, aip := range p.AllowedIPs {
sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String()))
}
}
return sb.String()
}
@@ -599,7 +614,9 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
continue
}
if val != "" && val != "0000000000000000000000000000000000000000000000000000000000000000" {
currentPeer.PresharedKey = true
if pskKey, err := hexToWireguardKey(val); err == nil {
currentPeer.PresharedKey = [32]byte(pskKey)
}
}
}
}

View File

@@ -12,7 +12,7 @@ type Peer struct {
TxBytes int64
RxBytes int64
LastHandshake time.Time
PresharedKey bool
PresharedKey [32]byte
}
type Stats struct {

View File

@@ -17,6 +17,7 @@ type WGConfigurer interface {
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
Close()
GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error)

View File

@@ -297,6 +297,19 @@ func (w *WGIface) FullStats() (*configurer.Stats, error) {
return w.configurer.FullStats()
}
// SetPresharedKey sets or updates the preshared key for a peer.
// If updateOnly is true, only updates existing peer; if false, creates or updates.
func (w *WGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
return w.configurer.SetPresharedKey(peerKey, psk, updateOnly)
}
func (w *WGIface) waitUntilRemoved() error {
maxWaitTime := 5 * time.Second
timeout := time.NewTimer(maxWaitTime)

View File

@@ -117,16 +117,29 @@ func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
p.paused = false
p.wgCurrentUsed = addrToEndpoint(endpoint)
ep, err := addrToEndpoint(endpoint)
if err != nil {
log.Errorf("failed to convert endpoint address: %v", err)
} else {
p.wgCurrentUsed = ep
}
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
ip, _ := netip.AddrFromSlice(addr.IP.To4())
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}
func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
if addr == nil {
return nil, errors.New("nil address")
}
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return nil, fmt.Errorf("convert %s to netip.Addr", addr)
}
addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}, nil
}
func (p *ProxyBind) CloseConn() error {

View File

@@ -94,7 +94,9 @@ func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
p.paused = false
p.wgEndpointCurrentUsedAddr = endpoint
if endpoint != nil && endpoint.IP != nil {
p.wgEndpointCurrentUsedAddr = endpoint
}
p.pausedCond.Signal()
p.pausedCond.L.Unlock()

View File

@@ -59,7 +59,6 @@ func NewConnectClient(
config *profilemanager.Config,
statusRecorder *peer.Status,
doInitalAutoUpdate bool,
) *ConnectClient {
return &ConnectClient{
ctx: ctx,
@@ -71,8 +70,8 @@ func NewConnectClient(
}
// Run with main logic.
func (c *ConnectClient) Run(runningChan chan struct{}) error {
return c.run(MobileDependency{}, runningChan)
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
return c.run(MobileDependency{}, runningChan, logPath)
}
// RunOnAndroid with main logic on mobile system
@@ -93,7 +92,7 @@ func (c *ConnectClient) RunOnAndroid(
DnsReadyListener: dnsReadyListener,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil)
return c.run(mobileDependency, nil, "")
}
func (c *ConnectClient) RunOniOS(
@@ -111,10 +110,10 @@ func (c *ConnectClient) RunOniOS(
DnsManager: dnsManager,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil)
return c.run(mobileDependency, nil, "")
}
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error {
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
defer func() {
if r := recover(); r != nil {
rec := c.statusRecorder
@@ -284,7 +283,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
relayURLs, token := parseRelayInfo(loginResp)
peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
if err != nil {
log.Error(err)
return wrapErr(err)
@@ -472,7 +471,7 @@ func (c *ConnectClient) SetSyncResponsePersistence(enabled bool) {
}
// createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig, logPath string) (*EngineConfig, error) {
nm := false
if config.NetworkMonitor != nil {
nm = *config.NetworkMonitor
@@ -507,7 +506,10 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
LazyConnectionEnabled: config.LazyConnectionEnabled,
MTU: selectMTU(config.MTU, peerConfig.Mtu),
MTU: selectMTU(config.MTU, peerConfig.Mtu),
LogPath: logPath,
ProfileConfig: config,
}
if config.PreSharedKey != "" {

View File

@@ -28,8 +28,10 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
const readmeContent = `Netbird debug bundle
@@ -57,6 +59,7 @@ block.prof: Block profiling information.
heap.prof: Heap profiling information (snapshot of memory allocations).
allocs.prof: Allocations profiling information.
threadcreate.prof: Thread creation profiling information.
cpu.prof: CPU profiling information.
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
@@ -223,10 +226,11 @@ type BundleGenerator struct {
internalConfig *profilemanager.Config
statusRecorder *peer.Status
syncResponse *mgmProto.SyncResponse
logFile string
logPath string
cpuProfile []byte
refreshStatus func() // Optional callback to refresh status before bundle generation
anonymize bool
clientStatus string
includeSystemInfo bool
logFileCount uint32
@@ -235,7 +239,6 @@ type BundleGenerator struct {
type BundleConfig struct {
Anonymize bool
ClientStatus string
IncludeSystemInfo bool
LogFileCount uint32
}
@@ -244,7 +247,9 @@ type GeneratorDependencies struct {
InternalConfig *profilemanager.Config
StatusRecorder *peer.Status
SyncResponse *mgmProto.SyncResponse
LogFile string
LogPath string
CPUProfile []byte
RefreshStatus func() // Optional callback to refresh status before bundle generation
}
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
@@ -260,10 +265,11 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
internalConfig: deps.InternalConfig,
statusRecorder: deps.StatusRecorder,
syncResponse: deps.SyncResponse,
logFile: deps.LogFile,
logPath: deps.LogPath,
cpuProfile: deps.CPUProfile,
refreshStatus: deps.RefreshStatus,
anonymize: cfg.Anonymize,
clientStatus: cfg.ClientStatus,
includeSystemInfo: cfg.IncludeSystemInfo,
logFileCount: logFileCount,
}
@@ -309,13 +315,6 @@ func (g *BundleGenerator) createArchive() error {
return fmt.Errorf("add status: %w", err)
}
if g.statusRecorder != nil {
status := g.statusRecorder.GetFullStatus()
seedFromStatus(g.anonymizer, &status)
} else {
log.Debugf("no status recorder available for seeding")
}
if err := g.addConfig(); err != nil {
log.Errorf("failed to add config to debug bundle: %v", err)
}
@@ -332,6 +331,10 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add profiles to debug bundle: %v", err)
}
if err := g.addCPUProfile(); err != nil {
log.Errorf("failed to add CPU profile to debug bundle: %v", err)
}
if err := g.addStackTrace(); err != nil {
log.Errorf("failed to add stack trace to debug bundle: %v", err)
}
@@ -352,7 +355,7 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add wg show output: %v", err)
}
if g.logFile != "" && !slices.Contains(util.SpecialLogs, g.logFile) {
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
if err := g.addLogfile(); err != nil {
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
@@ -401,11 +404,30 @@ func (g *BundleGenerator) addReadme() error {
}
func (g *BundleGenerator) addStatus() error {
if status := g.clientStatus; status != "" {
statusReader := strings.NewReader(status)
if g.statusRecorder != nil {
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
if g.refreshStatus != nil {
g.refreshStatus()
}
fullStatus := g.statusRecorder.GetFullStatus()
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, g.anonymize, version.NetbirdVersion(), "", nil, nil, nil, "", profName)
statusOutput := overview.FullDetailSummary()
statusReader := strings.NewReader(statusOutput)
if err := g.addFileToZip(statusReader, "status.txt"); err != nil {
return fmt.Errorf("add status file to zip: %w", err)
}
seedFromStatus(g.anonymizer, &fullStatus)
} else {
log.Debugf("no status recorder available for seeding")
}
return nil
}
@@ -535,6 +557,19 @@ func (g *BundleGenerator) addProf() (err error) {
return nil
}
func (g *BundleGenerator) addCPUProfile() error {
if len(g.cpuProfile) == 0 {
return nil
}
reader := bytes.NewReader(g.cpuProfile)
if err := g.addFileToZip(reader, "cpu.prof"); err != nil {
return fmt.Errorf("add CPU profile to zip: %w", err)
}
return nil
}
func (g *BundleGenerator) addStackTrace() error {
buf := make([]byte, 5242880) // 5 MB buffer
n := runtime.Stack(buf, true)
@@ -710,14 +745,14 @@ func (g *BundleGenerator) addCorruptedStateFiles() error {
}
func (g *BundleGenerator) addLogfile() error {
if g.logFile == "" {
if g.logPath == "" {
log.Debugf("skipping empty log file in debug bundle")
return nil
}
logDir := filepath.Dir(g.logFile)
logDir := filepath.Dir(g.logPath)
if err := g.addSingleLogfile(g.logFile, clientLogFile); err != nil {
if err := g.addSingleLogfile(g.logPath, clientLogFile); err != nil {
return fmt.Errorf("add client log file to zip: %w", err)
}

View File

@@ -0,0 +1,101 @@
package debug
import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"github.com/netbirdio/netbird/upload-server/types"
)
const maxBundleUploadSize = 50 * 1024 * 1024
func UploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) {
response, err := getUploadURL(ctx, url, managementURL)
if err != nil {
return "", err
}
err = upload(ctx, filePath, response)
if err != nil {
return "", err
}
return response.Key, nil
}
func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error {
fileData, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("open file: %w", err)
}
defer fileData.Close()
stat, err := fileData.Stat()
if err != nil {
return fmt.Errorf("stat file: %w", err)
}
if stat.Size() > maxBundleUploadSize {
return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize)
}
req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData)
if err != nil {
return fmt.Errorf("create PUT request: %w", err)
}
req.ContentLength = stat.Size()
req.Header.Set("Content-Type", "application/octet-stream")
putResp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("upload failed: %v", err)
}
defer putResp.Body.Close()
if putResp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(putResp.Body)
return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body))
}
return nil
}
func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) {
id := getURLHash(managementURL)
getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil)
if err != nil {
return nil, fmt.Errorf("create GET request: %w", err)
}
getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue)
resp, err := http.DefaultClient.Do(getReq)
if err != nil {
return nil, fmt.Errorf("get presigned URL: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body))
}
urlBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response body: %w", err)
}
var response types.GetURLResponse
if err := json.Unmarshal(urlBytes, &response); err != nil {
return nil, fmt.Errorf("unmarshal response: %w", err)
}
return &response, nil
}
func getURLHash(url string) string {
return fmt.Sprintf("%x", sha256.Sum256([]byte(url)))
}

View File

@@ -1,4 +1,4 @@
package server
package debug
import (
"context"
@@ -38,7 +38,7 @@ func TestUpload(t *testing.T) {
fileContent := []byte("test file content")
err := os.WriteFile(file, fileContent, 0640)
require.NoError(t, err)
key, err := uploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
key, err := UploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
require.NoError(t, err)
id := getURLHash(testURL)
require.Contains(t, key, id+"/")

View File

@@ -60,7 +60,7 @@ func (g *BundleGenerator) toWGShowFormat(s *configurer.Stats) string {
}
sb.WriteString(fmt.Sprintf(" latest handshake: %s\n", peer.LastHandshake.Format(time.RFC1123)))
sb.WriteString(fmt.Sprintf(" transfer: %d B received, %d B sent\n", peer.RxBytes, peer.TxBytes))
if peer.PresharedKey {
if peer.PresharedKey != [32]byte{} {
sb.WriteString(" preshared key: (hidden)\n")
}
}

View File

@@ -81,7 +81,10 @@ func (d *Resolver) ProbeAvailability() {}
// ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
logger := log.WithField("request_id", resutil.GetRequestID(w))
logger := log.WithFields(log.Fields{
"request_id": resutil.GetRequestID(w),
"dns_id": fmt.Sprintf("%04x", r.Id),
})
if len(r.Question) == 0 {
logger.Debug("received local resolver request with no question")
@@ -120,7 +123,7 @@ func (d *Resolver) determineRcode(question dns.Question, result lookupResult) in
}
// No records found, but domain exists with different record types (NODATA)
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
if d.hasRecordsForDomain(domain.Domain(question.Name), question.Qtype) {
return dns.RcodeSuccess
}
@@ -164,11 +167,15 @@ func (d *Resolver) continueToNext(logger *log.Entry, w dns.ResponseWriter, r *dn
}
// hasRecordsForDomain checks if any records exist for the given domain name regardless of type
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain, qType uint16) bool {
d.mu.RLock()
defer d.mu.RUnlock()
_, exists := d.domains[domainName]
if !exists && supportsWildcard(qType) {
testWild := transformDomainToWildcard(string(domainName))
_, exists = d.domains[domain.Domain(testWild)]
}
return exists
}
@@ -195,6 +202,16 @@ type lookupResult struct {
func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) lookupResult {
d.mu.RLock()
records, found := d.records[question]
usingWildcard := false
wildQuestion := transformToWildcard(question)
// RFC 4592 section 2.2.1: wildcard only matches if the name does NOT exist in the zone.
// If the domain exists with any record type, return NODATA instead of wildcard match.
if !found && supportsWildcard(question.Qtype) {
if _, domainExists := d.domains[domain.Domain(question.Name)]; !domainExists {
records, found = d.records[wildQuestion]
usingWildcard = found
}
}
if !found {
d.mu.RUnlock()
@@ -216,18 +233,53 @@ func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) looku
// if there's more than one record, rotate them (round-robin)
if len(recordsCopy) > 1 {
d.mu.Lock()
records = d.records[question]
q := question
if usingWildcard {
q = wildQuestion
}
records = d.records[q]
if len(records) > 1 {
first := records[0]
records = append(records[1:], first)
d.records[question] = records
d.records[q] = records
}
d.mu.Unlock()
}
if usingWildcard {
return responseFromWildRecords(question.Name, wildQuestion.Name, recordsCopy)
}
return lookupResult{records: recordsCopy, rcode: dns.RcodeSuccess}
}
func transformToWildcard(question dns.Question) dns.Question {
wildQuestion := question
wildQuestion.Name = transformDomainToWildcard(wildQuestion.Name)
return wildQuestion
}
func transformDomainToWildcard(domain string) string {
s := strings.Split(domain, ".")
s[0] = "*"
return strings.Join(s, ".")
}
func supportsWildcard(queryType uint16) bool {
return queryType != dns.TypeNS && queryType != dns.TypeSOA
}
func responseFromWildRecords(originalName, wildName string, wildRecords []dns.RR) lookupResult {
records := make([]dns.RR, len(wildRecords))
for i, record := range wildRecords {
copiedRecord := dns.Copy(record)
copiedRecord.Header().Name = originalName
records[i] = copiedRecord
}
return lookupResult{records: records, rcode: dns.RcodeSuccess}
}
// lookupCNAMEChain follows a CNAME chain and returns the CNAME records along with
// the final resolved record of the requested type. This is required for musl libc
// compatibility, which expects the full answer chain rather than just the CNAME.
@@ -237,6 +289,13 @@ func (d *Resolver) lookupCNAMEChain(logger *log.Entry, cnameQuestion dns.Questio
for range maxDepth {
cnameRecords := d.getRecords(cnameQuestion)
if len(cnameRecords) == 0 && supportsWildcard(targetType) {
wildQuestion := transformToWildcard(cnameQuestion)
if wildRecords := d.getRecords(wildQuestion); len(wildRecords) > 0 {
cnameRecords = responseFromWildRecords(cnameQuestion.Name, wildQuestion.Name, wildRecords).records
}
}
if len(cnameRecords) == 0 {
break
}
@@ -303,7 +362,7 @@ func (d *Resolver) resolveCNAMETarget(logger *log.Entry, targetName string, targ
}
// domain exists locally but not this record type (NODATA)
if d.hasRecordsForDomain(domain.Domain(targetName)) {
if d.hasRecordsForDomain(domain.Domain(targetName), targetType) {
return lookupResult{rcode: dns.RcodeSuccess}
}

File diff suppressed because it is too large Load Diff

View File

@@ -71,6 +71,11 @@ type upstreamResolverBase struct {
statusRecorder *peer.Status
}
type upstreamFailure struct {
upstream netip.AddrPort
reason string
}
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
ctx, cancel := context.WithCancel(ctx)
@@ -114,7 +119,10 @@ func (u *upstreamResolverBase) Stop() {
// ServeDNS handles a DNS request
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
logger := log.WithField("request_id", resutil.GetRequestID(w))
logger := log.WithFields(log.Fields{
"request_id": resutil.GetRequestID(w),
"dns_id": fmt.Sprintf("%04x", r.Id),
})
u.prepareRequest(r)
@@ -123,11 +131,13 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
if u.tryUpstreamServers(w, r, logger) {
return
ok, failures := u.tryUpstreamServers(w, r, logger)
if len(failures) > 0 {
u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger)
}
if !ok {
u.writeErrorResponse(w, r, logger)
}
u.writeErrorResponse(w, r, logger)
}
func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
@@ -136,7 +146,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
}
}
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool {
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
timeout := u.upstreamTimeout
if len(u.upstreamServers) > 1 {
maxTotal := 5 * time.Second
@@ -149,15 +159,19 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
}
}
var failures []upstreamFailure
for _, upstream := range u.upstreamServers {
if u.queryUpstream(w, r, upstream, timeout, logger) {
return true
if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil {
failures = append(failures, *failure)
} else {
return true, failures
}
}
return false
return false, failures
}
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) bool {
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
var rm *dns.Msg
var t time.Duration
var err error
@@ -171,31 +185,32 @@ func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, u
}()
if err != nil {
u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger)
return false
return u.handleUpstreamError(err, upstream, startTime)
}
if rm == nil || !rm.Response {
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
return false
return &upstreamFailure{upstream: upstream, reason: "no response"}
}
return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
return nil
}
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) {
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err)
return
return &upstreamFailure{upstream: upstream, reason: err.Error()}
}
elapsed := time.Since(startTime)
timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout)
reason := fmt.Sprintf("timeout after %v", elapsed.Truncate(time.Millisecond))
if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" {
timeoutMsg += " " + peerInfo
reason += " " + peerInfo
}
timeoutMsg += fmt.Sprintf(" - error: %v", err)
logger.Warn(timeoutMsg)
return &upstreamFailure{upstream: upstream, reason: reason}
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
@@ -215,16 +230,34 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn
return true
}
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) {
totalUpstreams := len(u.upstreamServers)
failedCount := len(failures)
failureSummary := formatFailures(failures)
if succeeded {
logger.Warnf("%d/%d upstreams failed for domain=%s: %s", failedCount, totalUpstreams, domain, failureSummary)
} else {
logger.Errorf("%d/%d upstreams failed for domain=%s: %s", failedCount, totalUpstreams, domain, failureSummary)
}
}
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
m := new(dns.Msg)
m.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(m); err != nil {
logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
logger.Errorf("write error response for domain=%s: %s", r.Question[0].Name, err)
}
}
func formatFailures(failures []upstreamFailure) string {
parts := make([]string, 0, len(failures))
for _, f := range failures {
parts = append(parts, fmt.Sprintf("%s=%s", f.upstream, f.reason))
}
return strings.Join(parts, ", ")
}
// ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work
func (u *upstreamResolverBase) ProbeAvailability() {
@@ -468,7 +501,6 @@ func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upst
return reply, nil
}
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
func FormatPeerStatus(peerState *peer.State) string {
isConnected := peerState.ConnStatus == peer.StatusConnected

View File

@@ -2,6 +2,7 @@ package dns
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
@@ -9,6 +10,8 @@ import (
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
@@ -140,6 +143,23 @@ func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg)
return c.r, c.rtt, c.err
}
type mockUpstreamResponse struct {
msg *dns.Msg
err error
}
type mockUpstreamResolverPerServer struct {
responses map[string]mockUpstreamResponse
rtt time.Duration
}
func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
if r, ok := c.responses[upstream]; ok {
return r.msg, c.rtt, r.err
}
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
}
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
mockClient := &mockUpstreamResolver{
err: dns.ErrTime,
@@ -191,3 +211,267 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
t.Errorf("should be enabled")
}
}
func TestUpstreamResolver_Failover(t *testing.T) {
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
successAnswer := "192.0.2.100"
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
testCases := []struct {
name string
upstream1 mockUpstreamResponse
upstream2 mockUpstreamResponse
expectedRcode int
expectAnswer bool
expectTrySecond bool
}{
{
name: "success on first upstream",
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
expectedRcode: dns.RcodeSuccess,
expectAnswer: true,
expectTrySecond: false,
},
{
name: "SERVFAIL from first should try second",
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
expectedRcode: dns.RcodeSuccess,
expectAnswer: true,
expectTrySecond: true,
},
{
name: "REFUSED from first should try second",
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
expectedRcode: dns.RcodeSuccess,
expectAnswer: true,
expectTrySecond: true,
},
{
name: "NXDOMAIN from first should NOT try second",
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeNameError, "")},
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
expectedRcode: dns.RcodeNameError,
expectAnswer: false,
expectTrySecond: false,
},
{
name: "timeout from first should try second",
upstream1: mockUpstreamResponse{err: timeoutErr},
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
expectedRcode: dns.RcodeSuccess,
expectAnswer: true,
expectTrySecond: true,
},
{
name: "no response from first should try second",
upstream1: mockUpstreamResponse{msg: nil},
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
expectedRcode: dns.RcodeSuccess,
expectAnswer: true,
expectTrySecond: true,
},
{
name: "both upstreams return SERVFAIL",
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
expectedRcode: dns.RcodeServerFailure,
expectAnswer: false,
expectTrySecond: true,
},
{
name: "both upstreams timeout",
upstream1: mockUpstreamResponse{err: timeoutErr},
upstream2: mockUpstreamResponse{err: timeoutErr},
expectedRcode: dns.RcodeServerFailure,
expectAnswer: false,
expectTrySecond: true,
},
{
name: "first SERVFAIL then timeout",
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
upstream2: mockUpstreamResponse{err: timeoutErr},
expectedRcode: dns.RcodeServerFailure,
expectAnswer: false,
expectTrySecond: true,
},
{
name: "first timeout then SERVFAIL",
upstream1: mockUpstreamResponse{err: timeoutErr},
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
expectedRcode: dns.RcodeServerFailure,
expectAnswer: false,
expectTrySecond: true,
},
{
name: "first REFUSED then SERVFAIL",
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
expectedRcode: dns.RcodeServerFailure,
expectAnswer: false,
expectTrySecond: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var queriedUpstreams []string
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
upstream1.String(): tc.upstream1,
upstream2.String(): tc.upstream2,
},
rtt: time.Millisecond,
}
trackingClient := &trackingMockClient{
inner: mockClient,
queriedUpstreams: &queriedUpstreams,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: trackingClient,
upstreamServers: []netip.AddrPort{upstream1, upstream2},
upstreamTimeout: UpstreamTimeout,
}
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
resolver.ServeDNS(responseWriter, inputMSG)
require.NotNil(t, responseMSG, "should write a response")
assert.Equal(t, tc.expectedRcode, responseMSG.Rcode, "unexpected rcode")
if tc.expectAnswer {
require.NotEmpty(t, responseMSG.Answer, "expected answer records")
assert.Contains(t, responseMSG.Answer[0].String(), successAnswer)
}
if tc.expectTrySecond {
assert.Len(t, queriedUpstreams, 2, "should have tried both upstreams")
assert.Equal(t, upstream1.String(), queriedUpstreams[0])
assert.Equal(t, upstream2.String(), queriedUpstreams[1])
} else {
assert.Len(t, queriedUpstreams, 1, "should have only tried first upstream")
assert.Equal(t, upstream1.String(), queriedUpstreams[0])
}
})
}
}
type trackingMockClient struct {
inner *mockUpstreamResolverPerServer
queriedUpstreams *[]string
}
func (t *trackingMockClient) exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) {
*t.queriedUpstreams = append(*t.queriedUpstreams, upstream)
return t.inner.exchange(ctx, upstream, r)
}
func buildMockResponse(rcode int, answer string) *dns.Msg {
m := new(dns.Msg)
m.Response = true
m.Rcode = rcode
if rcode == dns.RcodeSuccess && answer != "" {
m.Answer = []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "example.com.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: net.ParseIP(answer),
},
}
}
return m
}
func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
upstream := netip.MustParseAddrPort("192.0.2.1:53")
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
upstream.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamServers: []netip.AddrPort{upstream},
upstreamTimeout: UpstreamTimeout,
}
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
resolver.ServeDNS(responseWriter, inputMSG)
require.NotNil(t, responseMSG, "should write a response")
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL")
}
func TestFormatFailures(t *testing.T) {
testCases := []struct {
name string
failures []upstreamFailure
expected string
}{
{
name: "empty slice",
failures: []upstreamFailure{},
expected: "",
},
{
name: "single failure",
failures: []upstreamFailure{
{upstream: netip.MustParseAddrPort("8.8.8.8:53"), reason: "SERVFAIL"},
},
expected: "8.8.8.8:53=SERVFAIL",
},
{
name: "multiple failures",
failures: []upstreamFailure{
{upstream: netip.MustParseAddrPort("8.8.8.8:53"), reason: "SERVFAIL"},
{upstream: netip.MustParseAddrPort("8.8.4.4:53"), reason: "timeout after 2s"},
},
expected: "8.8.8.8:53=SERVFAIL, 8.8.4.4:53=timeout after 2s",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := formatFailures(tc.failures)
assert.Equal(t, tc.expected, result)
})
}
}

View File

@@ -31,6 +31,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
@@ -42,12 +43,14 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager"
"github.com/netbirdio/netbird/client/jobexec"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
@@ -132,6 +135,11 @@ type EngineConfig struct {
LazyConnectionEnabled bool
MTU uint16
// for debug bundle generation
ProfileConfig *profilemanager.Config
LogPath string
}
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -195,7 +203,8 @@ type Engine struct {
stateManager *statemanager.Manager
srWatcher *guard.SRWatcher
// Sync response persistence
// Sync response persistence (protected by syncRespMux)
syncRespMux sync.RWMutex
persistSyncResponse bool
latestSyncResponse *mgmProto.SyncResponse
connSemaphore *semaphoregroup.SemaphoreGroup
@@ -211,6 +220,9 @@ type Engine struct {
shutdownWg sync.WaitGroup
probeStunTurn *relay.StunTurnProbe
jobExecutor *jobexec.Executor
jobExecutorWG sync.WaitGroup
}
// Peer is an instance of the Connection Peer
@@ -224,7 +236,18 @@ type localIpUpdater interface {
}
// NewEngine creates a new Connection Engine with probes attached
func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, stateManager *statemanager.Manager) *Engine {
func NewEngine(
clientCtx context.Context,
clientCancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
relayManager *relayClient.Manager,
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
checks []*mgmProto.Checks,
stateManager *statemanager.Manager,
) *Engine {
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
@@ -244,6 +267,7 @@ func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signa
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(),
}
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
@@ -312,6 +336,8 @@ func (e *Engine) Stop() error {
e.cancel()
}
e.jobExecutorWG.Wait() // block until job goroutines finish
e.close()
// stop flow manager after wg interface is gone
@@ -479,6 +505,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return fmt.Errorf("up wg interface: %w", err)
}
// Set the WireGuard interface for rosenpass after interface is up
if e.rpManager != nil {
e.rpManager.SetInterface(e.wgInterface)
}
// if inbound conns are blocked there is no need to create the ACL manager
if e.firewall != nil && !e.config.BlockInbound {
e.acl = acl.NewDefaultManager(e.firewall)
@@ -500,6 +531,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.receiveSignalEvents()
e.receiveManagementEvents()
e.receiveJobEvents()
// starting network monitor at the very last to avoid disruptions
e.startNetworkMonitor()
@@ -828,9 +860,18 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return nil
}
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
// Read the storage-enabled flag under the syncRespMux too.
e.syncRespMux.RLock()
enabled := e.persistSyncResponse
e.syncRespMux.RUnlock()
// Store sync response if persistence is enabled
if e.persistSyncResponse {
if enabled {
e.syncRespMux.Lock()
e.latestSyncResponse = update
e.syncRespMux.Unlock()
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
}
@@ -960,6 +1001,80 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
return nil
}
func (e *Engine) receiveJobEvents() {
e.jobExecutorWG.Add(1)
go func() {
defer e.jobExecutorWG.Done()
err := e.mgmClient.Job(e.ctx, func(msg *mgmProto.JobRequest) *mgmProto.JobResponse {
resp := mgmProto.JobResponse{
ID: msg.ID,
Status: mgmProto.JobStatus_failed,
}
switch params := msg.WorkloadParameters.(type) {
case *mgmProto.JobRequest_Bundle:
bundleResult, err := e.handleBundle(params.Bundle)
if err != nil {
log.Errorf("handling bundle: %v", err)
resp.Reason = []byte(err.Error())
return &resp
}
resp.Status = mgmProto.JobStatus_succeeded
resp.WorkloadResults = bundleResult
return &resp
default:
resp.Reason = []byte(jobexec.ErrJobNotImplemented.Error())
return &resp
}
})
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
return
}
log.Info("stopped receiving jobs from Management Service")
}()
log.Info("connecting to Management Service jobs stream")
}
func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobResponse_Bundle, error) {
log.Infof("handle remote debug bundle request: %s", params.String())
syncResponse, err := e.GetLatestSyncResponse()
if err != nil {
log.Warnf("get latest sync response: %v", err)
}
bundleDeps := debug.GeneratorDependencies{
InternalConfig: e.config.ProfileConfig,
StatusRecorder: e.statusRecorder,
SyncResponse: syncResponse,
LogPath: e.config.LogPath,
RefreshStatus: func() {
e.RunHealthProbes(true)
},
}
bundleJobParams := debug.BundleConfig{
Anonymize: params.Anonymize,
IncludeSystemInfo: true,
LogFileCount: uint32(params.LogFileCount),
}
waitFor := time.Duration(params.BundleForTime) * time.Minute
uploadKey, err := e.jobExecutor.BundleJob(e.ctx, bundleDeps, bundleJobParams, waitFor, e.config.ProfileConfig.ManagementURL.String())
if err != nil {
return nil, err
}
response := &mgmProto.JobResponse_Bundle{
Bundle: &mgmProto.BundleResult{
UploadKey: uploadKey,
},
}
return response, nil
}
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
// E.g. when a new peer has been registered and we are allowed to connect to it.
@@ -1405,6 +1520,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
if e.rpManager != nil {
peerConn.SetOnConnected(e.rpManager.OnConnected)
peerConn.SetOnDisconnected(e.rpManager.OnDisconnected)
peerConn.SetRosenpassInitializedPresharedKeyValidator(e.rpManager.IsPresharedKeyInitialized)
}
return peerConn, nil
@@ -1714,7 +1830,7 @@ func (e *Engine) getRosenpassAddr() string {
return ""
}
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
// RunHealthProbes executes health checks for Signal, Management, Relay, and WireGuard services
// and updates the status recorder with the latest states.
func (e *Engine) RunHealthProbes(waitForResult bool) bool {
e.syncMsgMux.Lock()
@@ -1728,23 +1844,8 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
stuns := slices.Clone(e.STUNs)
turns := slices.Clone(e.TURNs)
if e.wgInterface != nil {
stats, err := e.wgInterface.GetStats()
if err != nil {
log.Warnf("failed to get wireguard stats: %v", err)
e.syncMsgMux.Unlock()
return false
}
for _, key := range e.peerStore.PeersPubKey() {
// wgStats could be zero value, in which case we just reset the stats
wgStats, ok := stats[key]
if !ok {
continue
}
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
}
}
if err := e.statusRecorder.RefreshWireGuardStats(); err != nil {
log.Debugf("failed to refresh WireGuard stats: %v", err)
}
e.syncMsgMux.Unlock()
@@ -1848,8 +1949,8 @@ func (e *Engine) stopDNSServer() {
// SetSyncResponsePersistence enables or disables sync response persistence
func (e *Engine) SetSyncResponsePersistence(enabled bool) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.syncRespMux.Lock()
defer e.syncRespMux.Unlock()
if enabled == e.persistSyncResponse {
return
@@ -1864,20 +1965,22 @@ func (e *Engine) SetSyncResponsePersistence(enabled bool) {
// GetLatestSyncResponse returns the stored sync response if persistence is enabled
func (e *Engine) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.syncRespMux.RLock()
enabled := e.persistSyncResponse
latest := e.latestSyncResponse
e.syncRespMux.RUnlock()
if !e.persistSyncResponse {
if !enabled {
return nil, errors.New("sync response persistence is disabled")
}
if e.latestSyncResponse == nil {
if latest == nil {
//nolint:nilnil
return nil, nil
}
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(e.latestSyncResponse))
sr, ok := proto.Clone(e.latestSyncResponse).(*mgmProto.SyncResponse)
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(latest))
sr, ok := proto.Clone(latest).(*mgmProto.SyncResponse)
if !ok {
return nil, fmt.Errorf("failed to clone sync response")
}

View File

@@ -25,6 +25,7 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/management-integrations/integrations"
@@ -213,6 +214,10 @@ func (m *MockWGIface) LastActivities() map[string]monotime.Time {
return nil
}
func (m *MockWGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
return nil
}
func TestMain(m *testing.M) {
_ = util.InitLog("debug", util.LogConsole)
code := m.Run()
@@ -1599,6 +1604,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore)
@@ -1622,7 +1628,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, "", err
}
@@ -1631,7 +1637,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
return nil, "", err
}

View File

@@ -42,4 +42,5 @@ type wgIfaceBase interface {
GetNet() *netstack.Net
FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
}

View File

@@ -88,8 +88,9 @@ type Conn struct {
relayManager *relayClient.Manager
srWatcher *guard.SRWatcher
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string)
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string)
rosenpassInitializedPresharedKeyValidator func(peerKey string) bool
statusRelay *worker.AtomicWorkerStatus
statusICE *worker.AtomicWorkerStatus
@@ -98,7 +99,10 @@ type Conn struct {
workerICE *WorkerICE
workerRelay *WorkerRelay
wgWatcherWg sync.WaitGroup
wgWatcher *WGWatcher
wgWatcherWg sync.WaitGroup
wgWatcherCancel context.CancelFunc
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
rosenpassRemoteKey []byte
@@ -126,6 +130,7 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
connLog := log.WithField("peer", config.Key)
dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
var conn = &Conn{
Log: connLog,
config: config,
@@ -137,8 +142,9 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
semaphore: services.Semaphore,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
dumpState: dumpState,
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
}
return conn, nil
@@ -162,7 +168,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
@@ -231,7 +237,9 @@ func (conn *Conn) Close(signalToRemote bool) {
conn.Log.Infof("close peer connection")
conn.ctxCancel()
conn.workerRelay.DisableWgWatcher()
if conn.wgWatcherCancel != nil {
conn.wgWatcherCancel()
}
conn.workerRelay.CloseConn()
conn.workerICE.Close()
@@ -289,6 +297,13 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
conn.onDisconnected = handler
}
// SetRosenpassInitializedPresharedKeyValidator sets a function to check if Rosenpass has taken over
// PSK management for a peer. When this returns true, presharedKey() returns nil
// to prevent UpdatePeer from overwriting the Rosenpass-managed PSK.
func (conn *Conn) SetRosenpassInitializedPresharedKeyValidator(handler func(peerKey string) bool) {
conn.rosenpassInitializedPresharedKeyValidator = handler
}
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) {
conn.dumpState.RemoteOffer()
conn.Log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
@@ -366,9 +381,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
ep = directEp
}
conn.workerRelay.DisableWgWatcher()
// todo consider to run conn.wgWatcherWg.Wait() here
if conn.wgProxyRelay != nil {
conn.wgProxyRelay.Pause()
}
@@ -390,6 +402,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
conn.wgProxyRelay.RedirectAs(ep)
}
conn.enableWgWatcherIfNeeded()
conn.currentConnPriority = priority
conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo)
@@ -423,11 +437,6 @@ func (conn *Conn) onICEStateDisconnected() {
conn.Log.Errorf("failed to switch to relay conn: %v", err)
}
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx)
}()
conn.wgProxyRelay.Work()
conn.currentConnPriority = conntype.Relay
} else {
@@ -444,15 +453,15 @@ func (conn *Conn) onICEStateDisconnected() {
}
conn.statusICE.SetDisconnected()
conn.disableWgWatcherIfNeeded()
peerState := State{
PubKey: conn.config.Key,
ConnStatus: conn.evalStatus(),
Relayed: conn.isRelayed(),
ConnStatusUpdate: time.Now(),
}
err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState)
if err != nil {
if err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState); err != nil {
conn.Log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
}
}
@@ -500,11 +509,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
return
}
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx)
}()
conn.enableWgWatcherIfNeeded()
wgConfigWorkaround()
conn.rosenpassRemoteKey = rci.rosenpassPubKey
@@ -519,7 +524,11 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
func (conn *Conn) onRelayDisconnected() {
conn.mu.Lock()
defer conn.mu.Unlock()
conn.handleRelayDisconnectedLocked()
}
// handleRelayDisconnectedLocked handles relay disconnection. Caller must hold conn.mu.
func (conn *Conn) handleRelayDisconnectedLocked() {
if conn.ctx.Err() != nil {
return
}
@@ -545,6 +554,8 @@ func (conn *Conn) onRelayDisconnected() {
}
conn.statusRelay.SetDisconnected()
conn.disableWgWatcherIfNeeded()
peerState := State{
PubKey: conn.config.Key,
ConnStatus: conn.evalStatus(),
@@ -563,6 +574,28 @@ func (conn *Conn) onGuardEvent() {
}
}
func (conn *Conn) onWGDisconnected() {
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.ctx.Err() != nil {
return
}
conn.Log.Warnf("WireGuard handshake timeout detected, closing current connection")
// Close the active connection based on current priority
switch conn.currentConnPriority {
case conntype.Relay:
conn.workerRelay.CloseConn()
conn.handleRelayDisconnectedLocked()
case conntype.ICEP2P, conntype.ICETurn:
conn.workerICE.Close()
default:
conn.Log.Debugf("No active connection to close on WG timeout")
}
}
func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
peerState := State{
PubKey: conn.config.Key,
@@ -689,6 +722,25 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
return true
}
func (conn *Conn) enableWgWatcherIfNeeded() {
if !conn.wgWatcher.IsEnabled() {
wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx)
conn.wgWatcherCancel = wgWatcherCancel
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, conn.onWGDisconnected)
}()
}
}
func (conn *Conn) disableWgWatcherIfNeeded() {
if conn.currentConnPriority == conntype.None && conn.wgWatcherCancel != nil {
conn.wgWatcherCancel()
conn.wgWatcherCancel = nil
}
}
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
conn.Log.Debugf("setup proxied WireGuard connection")
udpAddr := &net.UDPAddr{
@@ -759,10 +811,24 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
return conn.config.WgConfig.PreSharedKey
}
// If Rosenpass has already set a PSK for this peer, return nil to prevent
// UpdatePeer from overwriting the Rosenpass-managed key.
if conn.rosenpassInitializedPresharedKeyValidator != nil && conn.rosenpassInitializedPresharedKeyValidator(conn.config.Key) {
return nil
}
// Use NetBird PSK as the seed for Rosenpass. This same PSK is passed to
// Rosenpass as PeerConfig.PresharedKey, ensuring the derived post-quantum
// key is cryptographically bound to the original secret.
if conn.config.WgConfig.PreSharedKey != nil {
return conn.config.WgConfig.PreSharedKey
}
// Fallback to deterministic key if no NetBird PSK is configured
determKey, err := conn.rosenpassDetermKey()
if err != nil {
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
return conn.config.WgConfig.PreSharedKey
return nil
}
return determKey

View File

@@ -284,3 +284,27 @@ func TestConn_presharedKey(t *testing.T) {
})
}
}
func TestConn_presharedKey_RosenpassManaged(t *testing.T) {
conn := Conn{
config: ConnConfig{
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
RosenpassConfig: RosenpassConfig{PubKey: []byte("dummykey")},
},
}
// When Rosenpass has already initialized the PSK for this peer,
// presharedKey must return nil to avoid UpdatePeer overwriting it.
conn.rosenpassInitializedPresharedKeyValidator = func(peerKey string) bool { return true }
if k := conn.presharedKey([]byte("remote")); k != nil {
t.Fatalf("expected nil presharedKey when Rosenpass manages PSK, got %v", k)
}
// When Rosenpass hasn't taken over yet, presharedKey should provide
// a non-nil initial key (deterministic or from NetBird PSK).
conn.rosenpassInitializedPresharedKeyValidator = func(peerKey string) bool { return false }
if k := conn.presharedKey([]byte("remote")); k == nil {
t.Fatalf("expected non-nil presharedKey before Rosenpass manages PSK")
}
}

View File

@@ -1145,6 +1145,38 @@ func (d *Status) PeersStatus() (*configurer.Stats, error) {
return d.wgIface.FullStats()
}
// RefreshWireGuardStats fetches fresh WireGuard statistics from the interface
// and updates the cached peer states. This ensures accurate handshake times and
// transfer statistics in status reports without running full health probes.
func (d *Status) RefreshWireGuardStats() error {
d.mux.Lock()
defer d.mux.Unlock()
if d.wgIface == nil {
return nil // silently skip if interface not set
}
stats, err := d.wgIface.FullStats()
if err != nil {
return fmt.Errorf("get wireguard stats: %w", err)
}
// Update each peer's WireGuard statistics
for _, peerStats := range stats.Peers {
peerState, ok := d.peers[peerStats.PublicKey]
if !ok {
continue
}
peerState.LastWireguardHandshake = peerStats.LastHandshake
peerState.BytesRx = peerStats.RxBytes
peerState.BytesTx = peerStats.TxBytes
d.peers[peerStats.PublicKey] = peerState
}
return nil
}
type EventQueue struct {
maxSize int
events []*proto.SystemEvent

View File

@@ -30,10 +30,8 @@ type WGWatcher struct {
peerKey string
stateDump *stateDump
ctx context.Context
ctxCancel context.CancelFunc
ctxLock sync.Mutex
enabledTime time.Time
enabled bool
muEnabled sync.RWMutex
}
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
@@ -46,52 +44,44 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
}
// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing.
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
w.log.Debugf("enable WireGuard watcher")
w.ctxLock.Lock()
w.enabledTime = time.Now()
if w.ctx != nil && w.ctx.Err() == nil {
w.log.Errorf("WireGuard watcher already enabled")
w.ctxLock.Unlock()
// The watcher runs until ctx is cancelled. Caller is responsible for context lifecycle management.
func (w *WGWatcher) EnableWgWatcher(ctx context.Context, onDisconnectedFn func()) {
w.muEnabled.Lock()
if w.enabled {
w.muEnabled.Unlock()
return
}
ctx, ctxCancel := context.WithCancel(parentCtx)
w.ctx = ctx
w.ctxCancel = ctxCancel
w.ctxLock.Unlock()
w.log.Debugf("enable WireGuard watcher")
enabledTime := time.Now()
w.enabled = true
w.muEnabled.Unlock()
initialHandshake, err := w.wgState()
if err != nil {
w.log.Warnf("failed to read initial wg stats: %v", err)
}
w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
w.periodicHandshakeCheck(ctx, onDisconnectedFn, enabledTime, initialHandshake)
w.muEnabled.Lock()
w.enabled = false
w.muEnabled.Unlock()
}
// DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit
func (w *WGWatcher) DisableWgWatcher() {
w.ctxLock.Lock()
defer w.ctxLock.Unlock()
if w.ctxCancel == nil {
return
}
w.log.Debugf("disable WireGuard watcher")
w.ctxCancel()
w.ctxCancel = nil
// IsEnabled returns true if the WireGuard watcher is currently enabled
func (w *WGWatcher) IsEnabled() bool {
w.muEnabled.RLock()
defer w.muEnabled.RUnlock()
return w.enabled
}
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) {
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
w.log.Infof("WireGuard watcher started")
timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop()
defer ctxCancel()
lastHandshake := initialHandshake
@@ -104,7 +94,7 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex
return
}
if lastHandshake.IsZero() {
elapsed := handshake.Sub(w.enabledTime).Seconds()
elapsed := calcElapsed(enabledTime, *handshake)
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
}
@@ -134,19 +124,19 @@ func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
// the current know handshake did not change
if handshake.Equal(lastHandshake) {
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
w.log.Warnf("WireGuard handshake timed out: %v", handshake)
return nil, false
}
// in case if the machine is suspended, the handshake time will be in the past
if handshake.Add(checkPeriod).Before(time.Now()) {
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
w.log.Warnf("WireGuard handshake timed out: %v", handshake)
return nil, false
}
// error handling for handshake time in the future
if handshake.After(time.Now()) {
w.log.Warnf("WireGuard handshake is in the future, closing relay connection: %v", handshake)
w.log.Warnf("WireGuard handshake is in the future: %v", handshake)
return nil, false
}
@@ -164,3 +154,13 @@ func (w *WGWatcher) wgState() (time.Time, error) {
}
return wgState.LastHandshake, nil
}
// calcElapsed calculates elapsed time since watcher was enabled.
// The watcher started after the wg configuration happens, because of this need to normalise the negative value
func calcElapsed(enabledTime, handshake time.Time) float64 {
elapsed := handshake.Sub(enabledTime).Seconds()
if elapsed < 0 {
elapsed = 0
}
return elapsed
}

View File

@@ -2,6 +2,7 @@ package peer
import (
"context"
"sync"
"testing"
"time"
@@ -48,7 +49,6 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
case <-time.After(10 * time.Second):
t.Errorf("timeout")
}
watcher.DisableWgWatcher()
}
func TestWGWatcher_ReEnable(t *testing.T) {
@@ -60,14 +60,21 @@ func TestWGWatcher_ReEnable(t *testing.T) {
watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump("peer", mlog, &Status{}))
ctx, cancel := context.WithCancel(context.Background())
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
watcher.EnableWgWatcher(ctx, func() {})
}()
cancel()
wg.Wait()
// Re-enable with a new context
ctx, cancel = context.WithCancel(context.Background())
defer cancel()
onDisconnected := make(chan struct{}, 1)
go watcher.EnableWgWatcher(ctx, func() {})
time.Sleep(1 * time.Second)
watcher.DisableWgWatcher()
go watcher.EnableWgWatcher(ctx, func() {
onDisconnected <- struct{}{}
})
@@ -80,5 +87,4 @@ func TestWGWatcher_ReEnable(t *testing.T) {
case <-time.After(10 * time.Second):
t.Errorf("timeout")
}
watcher.DisableWgWatcher()
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/netip"
"strconv"
"sync"
"time"
@@ -286,8 +287,8 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
RosenpassAddr: remoteOfferAnswer.RosenpassAddr,
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
LocalIceCandidateEndpoint: net.JoinHostPort(pair.Local.Address(), strconv.Itoa(pair.Local.Port())),
RemoteIceCandidateEndpoint: net.JoinHostPort(pair.Remote.Address(), strconv.Itoa(pair.Remote.Port())),
Relayed: isRelayed(pair),
RelayedOnLocal: isRelayCandidate(pair.Local),
}
@@ -328,13 +329,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
// wait local endpoint configuration
time.Sleep(time.Second)
addrString := pair.Remote.Address()
parsed, err := netip.ParseAddr(addrString)
if (err == nil) && (parsed.Is6()) {
addrString = fmt.Sprintf("[%s]", addrString)
//IPv6 Literals need to be wrapped in brackets for Resolve*Addr()
}
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addrString, remoteWgPort))
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(pair.Remote.Address(), strconv.Itoa(remoteWgPort)))
if err != nil {
w.log.Warnf("got an error while resolving the udp address, err: %s", err)
return
@@ -386,12 +381,44 @@ func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent,
}
}
func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) {
sessionID := w.SessionID()
stats := agent.GetCandidatePairsStats()
localCandidates, _ := agent.GetLocalCandidates()
remoteCandidates, _ := agent.GetRemoteCandidates()
localMap := make(map[string]ice.Candidate)
for _, c := range localCandidates {
localMap[c.ID()] = c
}
remoteMap := make(map[string]ice.Candidate)
for _, c := range remoteCandidates {
remoteMap[c.ID()] = c
}
for _, stat := range stats {
if stat.State == ice.CandidatePairStateSucceeded {
local, lok := localMap[stat.LocalCandidateID]
remote, rok := remoteMap[stat.RemoteCandidateID]
if !lok || !rok {
continue
}
w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms",
sessionID,
local.NetworkType(), local.Type(), local.Address(),
remote.NetworkType(), remote.Type(), remote.Address(),
stat.CurrentRoundTripTime*1000)
}
}
}
func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dialerCancel context.CancelFunc) func(ice.ConnectionState) {
return func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
switch state {
case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected
w.logSuccessfulPaths(agent)
return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected, ice.ConnectionStateClosed:
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to

View File

@@ -30,11 +30,9 @@ type WorkerRelay struct {
relayLock sync.Mutex
relaySupportedOnRemotePeer atomic.Bool
wgWatcher *WGWatcher
}
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager, stateDump *stateDump) *WorkerRelay {
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager) *WorkerRelay {
r := &WorkerRelay{
peerCtx: ctx,
log: log,
@@ -42,7 +40,6 @@ func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnC
config: config,
conn: conn,
relayManager: relayManager,
wgWatcher: NewWGWatcher(log, config.WgConfig.WgInterface, config.Key, stateDump),
}
return r
}
@@ -93,14 +90,6 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
})
}
func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
w.wgWatcher.EnableWgWatcher(ctx, w.onWGDisconnected)
}
func (w *WorkerRelay) DisableWgWatcher() {
w.wgWatcher.DisableWgWatcher()
}
func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
return w.relayManager.RelayInstanceAddress()
}
@@ -125,14 +114,6 @@ func (w *WorkerRelay) CloseConn() {
}
}
func (w *WorkerRelay) onWGDisconnected() {
w.relayLock.Lock()
_ = w.relayedConn.Close()
w.relayLock.Unlock()
w.conn.onRelayDisconnected()
}
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
if !w.relayManager.HasRelayAddress() {
return false
@@ -148,6 +129,5 @@ func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress st
}
func (w *WorkerRelay) onRelayClientDisconnected() {
w.wgWatcher.DisableWgWatcher()
go w.conn.onRelayDisconnected()
}

View File

@@ -17,6 +17,11 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
const (
defaultLog = slog.LevelInfo
defaultLogLevelVar = "NB_ROSENPASS_LOG_LEVEL"
)
func hashRosenpassKey(key []byte) string {
hasher := sha256.New()
hasher.Write(key)
@@ -34,6 +39,7 @@ type Manager struct {
server *rp.Server
lock sync.Mutex
port int
wgIface PresharedKeySetter
}
// NewManager creates a new Rosenpass manager
@@ -44,7 +50,7 @@ func NewManager(preSharedKey *wgtypes.Key, wgIfaceName string) (*Manager, error)
}
rpKeyHash := hashRosenpassKey(public)
log.Debugf("generated new rosenpass key pair with public key %s", rpKeyHash)
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
return &Manager{ifaceName: wgIfaceName, rpKeyHash: rpKeyHash, spk: public, ssk: secret, preSharedKey: (*[32]byte)(preSharedKey), rpPeerIDs: make(map[string]*rp.PeerID), lock: sync.Mutex{}}, nil
}
@@ -100,7 +106,7 @@ func (m *Manager) removePeer(wireGuardPubKey string) error {
func (m *Manager) generateConfig() (rp.Config, error) {
opts := &slog.HandlerOptions{
Level: slog.LevelDebug,
Level: getLogLevel(),
}
logger := slog.New(slog.NewTextHandler(os.Stdout, opts))
cfg := rp.Config{Logger: logger}
@@ -109,7 +115,13 @@ func (m *Manager) generateConfig() (rp.Config, error) {
cfg.SecretKey = m.ssk
cfg.Peers = []rp.PeerConfig{}
m.rpWgHandler, _ = NewNetbirdHandler(m.preSharedKey, m.ifaceName)
m.lock.Lock()
m.rpWgHandler = NewNetbirdHandler()
if m.wgIface != nil {
m.rpWgHandler.SetInterface(m.wgIface)
}
m.lock.Unlock()
cfg.Handlers = []rp.Handler{m.rpWgHandler}
@@ -126,6 +138,26 @@ func (m *Manager) generateConfig() (rp.Config, error) {
return cfg, nil
}
func getLogLevel() slog.Level {
level, ok := os.LookupEnv(defaultLogLevelVar)
if !ok {
return defaultLog
}
switch strings.ToLower(level) {
case "debug":
return slog.LevelDebug
case "info":
return slog.LevelInfo
case "warn":
return slog.LevelWarn
case "error":
return slog.LevelError
default:
log.Warnf("unknown log level: %s. Using default %s", level, defaultLog.String())
return defaultLog
}
}
func (m *Manager) OnDisconnected(peerKey string) {
m.lock.Lock()
defer m.lock.Unlock()
@@ -172,6 +204,20 @@ func (m *Manager) Close() error {
return nil
}
// SetInterface sets the WireGuard interface for the rosenpass handler.
// This can be called before or after Run() - the interface will be stored
// and passed to the handler when it's created or updated immediately if
// already running.
func (m *Manager) SetInterface(iface PresharedKeySetter) {
m.lock.Lock()
defer m.lock.Unlock()
m.wgIface = iface
if m.rpWgHandler != nil {
m.rpWgHandler.SetInterface(iface)
}
}
// OnConnected is a handler function that is triggered when a connection to a remote peer establishes
func (m *Manager) OnConnected(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) {
m.lock.Lock()
@@ -192,6 +238,20 @@ func (m *Manager) OnConnected(remoteWireGuardKey string, remoteRosenpassPubKey [
}
}
// IsPresharedKeyInitialized returns true if Rosenpass has completed a handshake
// and set a PSK for the given WireGuard peer.
func (m *Manager) IsPresharedKeyInitialized(wireGuardPubKey string) bool {
m.lock.Lock()
defer m.lock.Unlock()
peerID, ok := m.rpPeerIDs[wireGuardPubKey]
if !ok || peerID == nil {
return false
}
return m.rpWgHandler.IsPeerInitialized(*peerID)
}
func findRandomAvailableUDPPort() (int, error) {
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {

View File

@@ -1,46 +1,50 @@
package rosenpass
import (
"fmt"
"log/slog"
"sync"
rp "cunicu.li/go-rosenpass"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// PresharedKeySetter is the interface for setting preshared keys on WireGuard peers.
// This minimal interface allows rosenpass to update PSKs without depending on the full WGIface.
type PresharedKeySetter interface {
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
}
type wireGuardPeer struct {
Interface string
PublicKey rp.Key
}
type NetbirdHandler struct {
ifaceName string
client *wgctrl.Client
peers map[rp.PeerID]wireGuardPeer
presharedKey [32]byte
mu sync.Mutex
iface PresharedKeySetter
peers map[rp.PeerID]wireGuardPeer
initializedPeers map[rp.PeerID]bool
}
func NewNetbirdHandler(preSharedKey *[32]byte, wgIfaceName string) (hdlr *NetbirdHandler, err error) {
hdlr = &NetbirdHandler{
ifaceName: wgIfaceName,
peers: map[rp.PeerID]wireGuardPeer{},
func NewNetbirdHandler() *NetbirdHandler {
return &NetbirdHandler{
peers: map[rp.PeerID]wireGuardPeer{},
initializedPeers: map[rp.PeerID]bool{},
}
}
if preSharedKey != nil {
hdlr.presharedKey = *preSharedKey
}
if hdlr.client, err = wgctrl.New(); err != nil {
return nil, fmt.Errorf("failed to creat WireGuard client: %w", err)
}
return hdlr, nil
// SetInterface sets the WireGuard interface for the handler.
// This must be called after the WireGuard interface is created.
func (h *NetbirdHandler) SetInterface(iface PresharedKeySetter) {
h.mu.Lock()
defer h.mu.Unlock()
h.iface = iface
}
func (h *NetbirdHandler) AddPeer(pid rp.PeerID, intf string, pk rp.Key) {
h.mu.Lock()
defer h.mu.Unlock()
h.peers[pid] = wireGuardPeer{
Interface: intf,
PublicKey: pk,
@@ -48,79 +52,61 @@ func (h *NetbirdHandler) AddPeer(pid rp.PeerID, intf string, pk rp.Key) {
}
func (h *NetbirdHandler) RemovePeer(pid rp.PeerID) {
h.mu.Lock()
defer h.mu.Unlock()
delete(h.peers, pid)
delete(h.initializedPeers, pid)
}
// IsPeerInitialized returns true if Rosenpass has completed a handshake
// and set a PSK for this peer.
func (h *NetbirdHandler) IsPeerInitialized(pid rp.PeerID) bool {
h.mu.Lock()
defer h.mu.Unlock()
return h.initializedPeers[pid]
}
func (h *NetbirdHandler) HandshakeCompleted(pid rp.PeerID, key rp.Key) {
log.Debug("Handshake complete")
h.outputKey(rp.KeyOutputReasonStale, pid, key)
}
func (h *NetbirdHandler) HandshakeExpired(pid rp.PeerID) {
key, _ := rp.GeneratePresharedKey()
log.Debug("Handshake expired")
h.outputKey(rp.KeyOutputReasonStale, pid, key)
}
func (h *NetbirdHandler) outputKey(_ rp.KeyOutputReason, pid rp.PeerID, psk rp.Key) {
h.mu.Lock()
iface := h.iface
wg, ok := h.peers[pid]
isInitialized := h.initializedPeers[pid]
h.mu.Unlock()
if iface == nil {
log.Warn("rosenpass: interface not set, cannot update preshared key")
return
}
if !ok {
return
}
device, err := h.client.Device(h.ifaceName)
if err != nil {
log.Errorf("Failed to get WireGuard device: %v", err)
peerKey := wgtypes.Key(wg.PublicKey).String()
pskKey := wgtypes.Key(psk)
// Use updateOnly=true for later rotations (peer already has Rosenpass PSK)
// Use updateOnly=false for first rotation (peer has original/empty PSK)
if err := iface.SetPresharedKey(peerKey, pskKey, isInitialized); err != nil {
log.Errorf("Failed to apply rosenpass key: %v", err)
return
}
config := []wgtypes.PeerConfig{
{
UpdateOnly: true,
PublicKey: wgtypes.Key(wg.PublicKey),
PresharedKey: (*wgtypes.Key)(&psk),
},
}
for _, peer := range device.Peers {
if peer.PublicKey == wgtypes.Key(wg.PublicKey) {
if publicKeyEmpty(peer.PresharedKey) || peer.PresharedKey == h.presharedKey {
log.Debugf("Restart wireguard connection to peer %s", peer.PublicKey)
config = []wgtypes.PeerConfig{
{
PublicKey: wgtypes.Key(wg.PublicKey),
PresharedKey: (*wgtypes.Key)(&psk),
Endpoint: peer.Endpoint,
AllowedIPs: peer.AllowedIPs,
},
}
err = h.client.ConfigureDevice(wg.Interface, wgtypes.Config{
Peers: []wgtypes.PeerConfig{
{
Remove: true,
PublicKey: wgtypes.Key(wg.PublicKey),
},
},
})
if err != nil {
slog.Debug("Failed to remove peer")
return
}
}
// Mark peer as isInitialized after the successful first rotation
if !isInitialized {
h.mu.Lock()
if _, exists := h.peers[pid]; exists {
h.initializedPeers[pid] = true
}
}
if err = h.client.ConfigureDevice(wg.Interface, wgtypes.Config{
Peers: config,
}); err != nil {
log.Errorf("Failed to apply rosenpass key: %v", err)
h.mu.Unlock()
}
}
func publicKeyEmpty(key wgtypes.Key) bool {
for _, b := range key {
if b != 0 {
return false
}
}
return true
}

View File

@@ -0,0 +1,76 @@
package jobexec
import (
"context"
"errors"
"fmt"
"os"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/upload-server/types"
)
const (
MaxBundleWaitTime = 60 * time.Minute // maximum wait time for bundle generation (1 hour)
)
var (
ErrJobNotImplemented = errors.New("job not implemented")
)
type Executor struct {
}
func NewExecutor() *Executor {
return &Executor{}
}
func (e *Executor) BundleJob(ctx context.Context, debugBundleDependencies debug.GeneratorDependencies, params debug.BundleConfig, waitForDuration time.Duration, mgmURL string) (string, error) {
if waitForDuration > MaxBundleWaitTime {
log.Warnf("bundle wait time %v exceeds maximum %v, capping to maximum", waitForDuration, MaxBundleWaitTime)
waitForDuration = MaxBundleWaitTime
}
if waitForDuration > 0 {
if err := waitFor(ctx, waitForDuration); err != nil {
return "", err
}
}
log.Infof("execute debug bundle generation")
bundleGenerator := debug.NewBundleGenerator(debugBundleDependencies, params)
path, err := bundleGenerator.Generate()
if err != nil {
return "", fmt.Errorf("generate debug bundle: %w", err)
}
defer func() {
if err := os.Remove(path); err != nil {
log.Errorf("failed to remove debug bundle file: %v", err)
}
}()
key, err := debug.UploadDebugBundle(ctx, types.DefaultBundleURL, mgmURL, path)
if err != nil {
log.Errorf("failed to upload debug bundle: %v", err)
return "", fmt.Errorf("upload debug bundle: %w", err)
}
log.Infof("debug bundle has been generated successfully")
return key, nil
}
func waitFor(ctx context.Context, duration time.Duration) error {
log.Infof("wait for %v minutes before executing debug bundle", duration.Minutes())
select {
case <-time.After(duration):
return nil
case <-ctx.Done():
log.Infof("wait cancelled: %v", ctx.Err())
return ctx.Err()
}
}

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.6
// protoc v3.21.12
// protoc v6.32.1
// source: daemon.proto
package proto
@@ -2757,7 +2757,6 @@ func (x *ForwardingRulesResponse) GetRules() []*ForwardingRule {
type DebugBundleRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
Anonymize bool `protobuf:"varint,1,opt,name=anonymize,proto3" json:"anonymize,omitempty"`
Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"`
SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"`
UploadURL string `protobuf:"bytes,4,opt,name=uploadURL,proto3" json:"uploadURL,omitempty"`
LogFileCount uint32 `protobuf:"varint,5,opt,name=logFileCount,proto3" json:"logFileCount,omitempty"`
@@ -2802,13 +2801,6 @@ func (x *DebugBundleRequest) GetAnonymize() bool {
return false
}
func (x *DebugBundleRequest) GetStatus() string {
if x != nil {
return x.Status
}
return ""
}
func (x *DebugBundleRequest) GetSystemInfo() bool {
if x != nil {
return x.SystemInfo
@@ -5372,6 +5364,154 @@ func (x *WaitJWTTokenResponse) GetExpiresIn() int64 {
return 0
}
// StartCPUProfileRequest for starting CPU profiling
type StartCPUProfileRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *StartCPUProfileRequest) Reset() {
*x = StartCPUProfileRequest{}
mi := &file_daemon_proto_msgTypes[79]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *StartCPUProfileRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*StartCPUProfileRequest) ProtoMessage() {}
func (x *StartCPUProfileRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[79]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use StartCPUProfileRequest.ProtoReflect.Descriptor instead.
func (*StartCPUProfileRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{79}
}
// StartCPUProfileResponse confirms CPU profiling has started
type StartCPUProfileResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *StartCPUProfileResponse) Reset() {
*x = StartCPUProfileResponse{}
mi := &file_daemon_proto_msgTypes[80]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *StartCPUProfileResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*StartCPUProfileResponse) ProtoMessage() {}
func (x *StartCPUProfileResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[80]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use StartCPUProfileResponse.ProtoReflect.Descriptor instead.
func (*StartCPUProfileResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{80}
}
// StopCPUProfileRequest for stopping CPU profiling
type StopCPUProfileRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *StopCPUProfileRequest) Reset() {
*x = StopCPUProfileRequest{}
mi := &file_daemon_proto_msgTypes[81]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *StopCPUProfileRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*StopCPUProfileRequest) ProtoMessage() {}
func (x *StopCPUProfileRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[81]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use StopCPUProfileRequest.ProtoReflect.Descriptor instead.
func (*StopCPUProfileRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{81}
}
// StopCPUProfileResponse confirms CPU profiling has stopped
type StopCPUProfileResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *StopCPUProfileResponse) Reset() {
*x = StopCPUProfileResponse{}
mi := &file_daemon_proto_msgTypes[82]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *StopCPUProfileResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*StopCPUProfileResponse) ProtoMessage() {}
func (x *StopCPUProfileResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[82]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use StopCPUProfileResponse.ProtoReflect.Descriptor instead.
func (*StopCPUProfileResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{82}
}
type InstallerResultRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
@@ -5380,7 +5520,7 @@ type InstallerResultRequest struct {
func (x *InstallerResultRequest) Reset() {
*x = InstallerResultRequest{}
mi := &file_daemon_proto_msgTypes[79]
mi := &file_daemon_proto_msgTypes[83]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -5392,7 +5532,7 @@ func (x *InstallerResultRequest) String() string {
func (*InstallerResultRequest) ProtoMessage() {}
func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[79]
mi := &file_daemon_proto_msgTypes[83]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -5405,7 +5545,7 @@ func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead.
func (*InstallerResultRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{79}
return file_daemon_proto_rawDescGZIP(), []int{83}
}
type InstallerResultResponse struct {
@@ -5418,7 +5558,7 @@ type InstallerResultResponse struct {
func (x *InstallerResultResponse) Reset() {
*x = InstallerResultResponse{}
mi := &file_daemon_proto_msgTypes[80]
mi := &file_daemon_proto_msgTypes[84]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -5430,7 +5570,7 @@ func (x *InstallerResultResponse) String() string {
func (*InstallerResultResponse) ProtoMessage() {}
func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[80]
mi := &file_daemon_proto_msgTypes[84]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -5443,7 +5583,7 @@ func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead.
func (*InstallerResultResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{80}
return file_daemon_proto_rawDescGZIP(), []int{84}
}
func (x *InstallerResultResponse) GetSuccess() bool {
@@ -5470,7 +5610,7 @@ type PortInfo_Range struct {
func (x *PortInfo_Range) Reset() {
*x = PortInfo_Range{}
mi := &file_daemon_proto_msgTypes[82]
mi := &file_daemon_proto_msgTypes[86]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -5482,7 +5622,7 @@ func (x *PortInfo_Range) String() string {
func (*PortInfo_Range) ProtoMessage() {}
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[82]
mi := &file_daemon_proto_msgTypes[86]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -5773,10 +5913,9 @@ const file_daemon_proto_rawDesc = "" +
"\x12translatedHostname\x18\x04 \x01(\tR\x12translatedHostname\x128\n" +
"\x0etranslatedPort\x18\x05 \x01(\v2\x10.daemon.PortInfoR\x0etranslatedPort\"G\n" +
"\x17ForwardingRulesResponse\x12,\n" +
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\xac\x01\n" +
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\x94\x01\n" +
"\x12DebugBundleRequest\x12\x1c\n" +
"\tanonymize\x18\x01 \x01(\bR\tanonymize\x12\x16\n" +
"\x06status\x18\x02 \x01(\tR\x06status\x12\x1e\n" +
"\tanonymize\x18\x01 \x01(\bR\tanonymize\x12\x1e\n" +
"\n" +
"systemInfo\x18\x03 \x01(\bR\n" +
"systemInfo\x12\x1c\n" +
@@ -6003,6 +6142,10 @@ const file_daemon_proto_rawDesc = "" +
"\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" +
"\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" +
"\texpiresIn\x18\x03 \x01(\x03R\texpiresIn\"\x18\n" +
"\x16StartCPUProfileRequest\"\x19\n" +
"\x17StartCPUProfileResponse\"\x17\n" +
"\x15StopCPUProfileRequest\"\x18\n" +
"\x16StopCPUProfileResponse\"\x18\n" +
"\x16InstallerResultRequest\"O\n" +
"\x17InstallerResultResponse\x12\x18\n" +
"\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +
@@ -6015,7 +6158,7 @@ const file_daemon_proto_rawDesc = "" +
"\x04WARN\x10\x04\x12\b\n" +
"\x04INFO\x10\x05\x12\t\n" +
"\x05DEBUG\x10\x06\x12\t\n" +
"\x05TRACE\x10\a2\xb4\x13\n" +
"\x05TRACE\x10\a2\xdd\x14\n" +
"\rDaemonService\x126\n" +
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
@@ -6050,7 +6193,9 @@ const file_daemon_proto_rawDesc = "" +
"\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" +
"\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" +
"\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" +
"\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12N\n" +
"\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12T\n" +
"\x0fStartCPUProfile\x12\x1e.daemon.StartCPUProfileRequest\x1a\x1f.daemon.StartCPUProfileResponse\"\x00\x12Q\n" +
"\x0eStopCPUProfile\x12\x1d.daemon.StopCPUProfileRequest\x1a\x1e.daemon.StopCPUProfileResponse\"\x00\x12N\n" +
"\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" +
"\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3"
@@ -6067,7 +6212,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
}
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 84)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 88)
var file_daemon_proto_goTypes = []any{
(LogLevel)(0), // 0: daemon.LogLevel
(OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType
@@ -6152,21 +6297,25 @@ var file_daemon_proto_goTypes = []any{
(*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse
(*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest
(*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse
(*InstallerResultRequest)(nil), // 83: daemon.InstallerResultRequest
(*InstallerResultResponse)(nil), // 84: daemon.InstallerResultResponse
nil, // 85: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 86: daemon.PortInfo.Range
nil, // 87: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 88: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 89: google.protobuf.Timestamp
(*StartCPUProfileRequest)(nil), // 83: daemon.StartCPUProfileRequest
(*StartCPUProfileResponse)(nil), // 84: daemon.StartCPUProfileResponse
(*StopCPUProfileRequest)(nil), // 85: daemon.StopCPUProfileRequest
(*StopCPUProfileResponse)(nil), // 86: daemon.StopCPUProfileResponse
(*InstallerResultRequest)(nil), // 87: daemon.InstallerResultRequest
(*InstallerResultResponse)(nil), // 88: daemon.InstallerResultResponse
nil, // 89: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 90: daemon.PortInfo.Range
nil, // 91: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 92: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 93: google.protobuf.Timestamp
}
var file_daemon_proto_depIdxs = []int32{
1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
88, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
92, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
89, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
89, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
88, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
93, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
93, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
92, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
@@ -6177,8 +6326,8 @@ var file_daemon_proto_depIdxs = []int32{
57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
85, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
86, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
89, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
90, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
@@ -6189,10 +6338,10 @@ var file_daemon_proto_depIdxs = []int32{
54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
89, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
87, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
93, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
91, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
88, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
92, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
@@ -6226,43 +6375,47 @@ var file_daemon_proto_depIdxs = []int32{
77, // 62: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
5, // 65: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
83, // 66: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
8, // 67: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
10, // 68: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
12, // 69: daemon.DaemonService.Up:output_type -> daemon.UpResponse
14, // 70: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
16, // 71: daemon.DaemonService.Down:output_type -> daemon.DownResponse
18, // 72: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
29, // 73: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
31, // 74: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 75: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
36, // 76: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
38, // 77: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
40, // 78: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
42, // 79: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
45, // 80: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
47, // 81: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
49, // 82: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
51, // 83: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
55, // 84: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
57, // 85: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
59, // 86: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
61, // 87: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
63, // 88: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
65, // 89: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
67, // 90: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
69, // 91: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
72, // 92: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
74, // 93: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
76, // 94: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
78, // 95: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
80, // 96: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
82, // 97: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
6, // 98: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
84, // 99: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
67, // [67:100] is the sub-list for method output_type
34, // [34:67] is the sub-list for method input_type
83, // 65: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
85, // 66: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
5, // 67: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
87, // 68: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
8, // 69: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
10, // 70: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
12, // 71: daemon.DaemonService.Up:output_type -> daemon.UpResponse
14, // 72: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
16, // 73: daemon.DaemonService.Down:output_type -> daemon.DownResponse
18, // 74: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
29, // 75: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
31, // 76: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 77: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
36, // 78: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
38, // 79: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
40, // 80: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
42, // 81: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
45, // 82: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
47, // 83: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
49, // 84: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
51, // 85: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
55, // 86: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
57, // 87: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
59, // 88: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
61, // 89: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
63, // 90: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
65, // 91: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
67, // 92: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
69, // 93: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
72, // 94: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
74, // 95: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
76, // 96: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
78, // 97: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
80, // 98: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
82, // 99: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
84, // 100: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
86, // 101: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
6, // 102: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
88, // 103: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
69, // [69:104] is the sub-list for method output_type
34, // [34:69] is the sub-list for method input_type
34, // [34:34] is the sub-list for extension type_name
34, // [34:34] is the sub-list for extension extendee
0, // [0:34] is the sub-list for field type_name
@@ -6292,7 +6445,7 @@ func file_daemon_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
NumEnums: 4,
NumMessages: 84,
NumMessages: 88,
NumExtensions: 0,
NumServices: 1,
},

View File

@@ -94,6 +94,12 @@ service DaemonService {
// WaitJWTToken waits for JWT authentication completion
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
// StartCPUProfile starts CPU profiling in the daemon
rpc StartCPUProfile(StartCPUProfileRequest) returns (StartCPUProfileResponse) {}
// StopCPUProfile stops CPU profiling in the daemon
rpc StopCPUProfile(StopCPUProfileRequest) returns (StopCPUProfileResponse) {}
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
@@ -455,7 +461,6 @@ message ForwardingRulesResponse {
// DebugBundler
message DebugBundleRequest {
bool anonymize = 1;
string status = 2;
bool systemInfo = 3;
string uploadURL = 4;
uint32 logFileCount = 5;
@@ -777,6 +782,18 @@ message WaitJWTTokenResponse {
int64 expiresIn = 3;
}
// StartCPUProfileRequest for starting CPU profiling
message StartCPUProfileRequest {}
// StartCPUProfileResponse confirms CPU profiling has started
message StartCPUProfileResponse {}
// StopCPUProfileRequest for stopping CPU profiling
message StopCPUProfileRequest {}
// StopCPUProfileResponse confirms CPU profiling has stopped
message StopCPUProfileResponse {}
message InstallerResultRequest {
}

View File

@@ -70,6 +70,10 @@ type DaemonServiceClient interface {
RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error)
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
// StartCPUProfile starts CPU profiling in the daemon
StartCPUProfile(ctx context.Context, in *StartCPUProfileRequest, opts ...grpc.CallOption) (*StartCPUProfileResponse, error)
// StopCPUProfile stops CPU profiling in the daemon
StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error)
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
}
@@ -384,6 +388,24 @@ func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTToken
return out, nil
}
func (c *daemonServiceClient) StartCPUProfile(ctx context.Context, in *StartCPUProfileRequest, opts ...grpc.CallOption) (*StartCPUProfileResponse, error) {
out := new(StartCPUProfileResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/StartCPUProfile", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error) {
out := new(StopCPUProfileResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/StopCPUProfile", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) {
out := new(OSLifecycleResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/NotifyOSLifecycle", in, out, opts...)
@@ -458,6 +480,10 @@ type DaemonServiceServer interface {
RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error)
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
// StartCPUProfile starts CPU profiling in the daemon
StartCPUProfile(context.Context, *StartCPUProfileRequest) (*StartCPUProfileResponse, error)
// StopCPUProfile stops CPU profiling in the daemon
StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error)
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
@@ -560,6 +586,12 @@ func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *Request
func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented")
}
func (UnimplementedDaemonServiceServer) StartCPUProfile(context.Context, *StartCPUProfileRequest) (*StartCPUProfileResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method StartCPUProfile not implemented")
}
func (UnimplementedDaemonServiceServer) StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method StopCPUProfile not implemented")
}
func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented")
}
@@ -1140,6 +1172,42 @@ func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, d
return interceptor(ctx, in, info, handler)
}
func _DaemonService_StartCPUProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(StartCPUProfileRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).StartCPUProfile(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/StartCPUProfile",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).StartCPUProfile(ctx, req.(*StartCPUProfileRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_StopCPUProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(StopCPUProfileRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).StopCPUProfile(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/StopCPUProfile",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).StopCPUProfile(ctx, req.(*StopCPUProfileRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(OSLifecycleRequest)
if err := dec(in); err != nil {
@@ -1303,6 +1371,14 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "WaitJWTToken",
Handler: _DaemonService_WaitJWTToken_Handler,
},
{
MethodName: "StartCPUProfile",
Handler: _DaemonService_StartCPUProfile_Handler,
},
{
MethodName: "StopCPUProfile",
Handler: _DaemonService_StopCPUProfile_Handler,
},
{
MethodName: "NotifyOSLifecycle",
Handler: _DaemonService_NotifyOSLifecycle_Handler,

View File

@@ -3,25 +3,19 @@
package server
import (
"bytes"
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"runtime/pprof"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/proto"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/upload-server/types"
)
const maxBundleUploadSize = 50 * 1024 * 1024
// DebugBundle creates a debug bundle and returns the location.
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
s.mutex.Lock()
@@ -32,16 +26,37 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
log.Warnf("failed to get latest sync response: %v", err)
}
var cpuProfileData []byte
if s.cpuProfileBuf != nil && !s.cpuProfiling {
cpuProfileData = s.cpuProfileBuf.Bytes()
defer func() {
s.cpuProfileBuf = nil
}()
}
// Prepare refresh callback for health probes
var refreshStatus func()
if s.connectClient != nil {
engine := s.connectClient.Engine()
if engine != nil {
refreshStatus = func() {
log.Debug("refreshing system health status for debug bundle")
engine.RunHealthProbes(true)
}
}
}
bundleGenerator := debug.NewBundleGenerator(
debug.GeneratorDependencies{
InternalConfig: s.config,
StatusRecorder: s.statusRecorder,
SyncResponse: syncResponse,
LogFile: s.logFile,
LogPath: s.logFile,
CPUProfile: cpuProfileData,
RefreshStatus: refreshStatus,
},
debug.BundleConfig{
Anonymize: req.GetAnonymize(),
ClientStatus: req.GetStatus(),
IncludeSystemInfo: req.GetSystemInfo(),
LogFileCount: req.GetLogFileCount(),
},
@@ -55,7 +70,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
if req.GetUploadURL() == "" {
return &proto.DebugBundleResponse{Path: path}, nil
}
key, err := uploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path)
key, err := debug.UploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path)
if err != nil {
log.Errorf("failed to upload debug bundle to %s: %v", req.GetUploadURL(), err)
return &proto.DebugBundleResponse{Path: path, UploadFailureReason: err.Error()}, nil
@@ -66,92 +81,6 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
return &proto.DebugBundleResponse{Path: path, UploadedKey: key}, nil
}
func uploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) {
response, err := getUploadURL(ctx, url, managementURL)
if err != nil {
return "", err
}
err = upload(ctx, filePath, response)
if err != nil {
return "", err
}
return response.Key, nil
}
func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error {
fileData, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("open file: %w", err)
}
defer fileData.Close()
stat, err := fileData.Stat()
if err != nil {
return fmt.Errorf("stat file: %w", err)
}
if stat.Size() > maxBundleUploadSize {
return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize)
}
req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData)
if err != nil {
return fmt.Errorf("create PUT request: %w", err)
}
req.ContentLength = stat.Size()
req.Header.Set("Content-Type", "application/octet-stream")
putResp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("upload failed: %v", err)
}
defer putResp.Body.Close()
if putResp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(putResp.Body)
return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body))
}
return nil
}
func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) {
id := getURLHash(managementURL)
getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil)
if err != nil {
return nil, fmt.Errorf("create GET request: %w", err)
}
getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue)
resp, err := http.DefaultClient.Do(getReq)
if err != nil {
return nil, fmt.Errorf("get presigned URL: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body))
}
urlBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response body: %w", err)
}
var response types.GetURLResponse
if err := json.Unmarshal(urlBytes, &response); err != nil {
return nil, fmt.Errorf("unmarshal response: %w", err)
}
return &response, nil
}
func getURLHash(url string) string {
return fmt.Sprintf("%x", sha256.Sum256([]byte(url)))
}
// GetLogLevel gets the current logging level for the server.
func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) {
s.mutex.Lock()
@@ -204,3 +133,43 @@ func (s *Server) getLatestSyncResponse() (*mgmProto.SyncResponse, error) {
return cClient.GetLatestSyncResponse()
}
// StartCPUProfile starts CPU profiling in the daemon.
func (s *Server) StartCPUProfile(_ context.Context, _ *proto.StartCPUProfileRequest) (*proto.StartCPUProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.cpuProfiling {
return nil, fmt.Errorf("CPU profiling already in progress")
}
s.cpuProfileBuf = &bytes.Buffer{}
s.cpuProfiling = true
if err := pprof.StartCPUProfile(s.cpuProfileBuf); err != nil {
s.cpuProfileBuf = nil
s.cpuProfiling = false
return nil, fmt.Errorf("start CPU profile: %w", err)
}
log.Info("CPU profiling started")
return &proto.StartCPUProfileResponse{}, nil
}
// StopCPUProfile stops CPU profiling in the daemon.
func (s *Server) StopCPUProfile(_ context.Context, _ *proto.StopCPUProfileRequest) (*proto.StopCPUProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if !s.cpuProfiling {
return nil, fmt.Errorf("CPU profiling not in progress")
}
pprof.StopCPUProfile()
s.cpuProfiling = false
if s.cpuProfileBuf != nil {
log.Infof("CPU profiling stopped, captured %d bytes", s.cpuProfileBuf.Len())
}
return &proto.StopCPUProfileResponse{}, nil
}

View File

@@ -1,6 +1,7 @@
package server
import (
"bytes"
"context"
"errors"
"fmt"
@@ -13,9 +14,8 @@ import (
"time"
"github.com/cenkalti/backoff/v4"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
gstatus "google.golang.org/grpc/status"
@@ -67,7 +67,7 @@ type Server struct {
proto.UnimplementedDaemonServiceServer
clientRunning bool // protected by mutex
clientRunningChan chan struct{}
clientGiveUpChan chan struct{}
clientGiveUpChan chan struct{} // closed when connectWithRetryRuns goroutine exits
connectClient *internal.ConnectClient
@@ -78,6 +78,9 @@ type Server struct {
persistSyncResponse bool
isSessionActive atomic.Bool
cpuProfileBuf *bytes.Buffer
cpuProfiling bool
profileManager *profilemanager.ServiceManager
profilesDisabled bool
updateSettingsDisabled bool
@@ -793,9 +796,11 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
// Down engine work in the daemon.
func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
giveUpChan := s.clientGiveUpChan
if err := s.cleanupConnection(); err != nil {
s.mutex.Unlock()
// todo review to update the status in case any type of error
log.Errorf("failed to shut down properly: %v", err)
return nil, err
@@ -804,6 +809,20 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
s.mutex.Unlock()
// Wait for the connectWithRetryRuns goroutine to finish with a short timeout.
// This prevents the goroutine from setting ErrResetConnection after Down() returns.
// The giveUpChan is closed at the end of connectWithRetryRuns.
if giveUpChan != nil {
select {
case <-giveUpChan:
log.Debugf("client goroutine finished successfully")
case <-time.After(5 * time.Second):
log.Warnf("timeout waiting for client goroutine to finish, proceeding anyway")
}
}
return &proto.DownResponse{}, nil
}
@@ -1308,6 +1327,10 @@ func (s *Server) runProbes(waitForProbeResult bool) {
if engine.RunHealthProbes(waitForProbeResult) {
s.lastProbe = time.Now()
}
} else {
if err := s.statusRecorder.RefreshWireGuardStats(); err != nil {
log.Debugf("failed to refresh WireGuard stats: %v", err)
}
}
}
@@ -1521,7 +1544,7 @@ func (s *Server) connect(ctx context.Context, config *profilemanager.Config, sta
log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
if err := s.connectClient.Run(runningChan); err != nil {
if err := s.connectClient.Run(runningChan, s.logFile); err != nil {
return err
}
return nil

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
@@ -306,6 +307,8 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
peersManager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersManager)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
@@ -317,7 +320,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
return nil, "", err
}
@@ -326,7 +329,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
return nil, "", err
}

View File

@@ -11,8 +11,12 @@ import (
"strings"
"time"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"gopkg.in/yaml.v3"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
probeRelay "github.com/netbirdio/netbird/client/internal/relay"
@@ -116,9 +120,7 @@ type OutputOverview struct {
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
}
func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview {
pbFullStatus := resp.GetFullStatus()
func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, anon bool, daemonVersion string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview {
managementState := pbFullStatus.GetManagementState()
managementOverview := ManagementStateOutput{
URL: managementState.GetURL(),
@@ -134,13 +136,13 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status
}
relayOverview := mapRelays(pbFullStatus.GetRelays())
peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter)
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
peersOverview := mapPeers(pbFullStatus.GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter)
overview := OutputOverview{
Peers: peersOverview,
CliVersion: version.NetbirdVersion(),
DaemonVersion: resp.GetDaemonVersion(),
DaemonVersion: daemonVersion,
ManagementState: managementOverview,
SignalState: signalOverview,
Relays: relayOverview,
@@ -489,6 +491,11 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
var forwardingRulesString string
if o.NumberOfForwardingRules > 0 {
forwardingRulesString = fmt.Sprintf("Forwarding rules: %d\n", o.NumberOfForwardingRules)
}
goos := runtime.GOOS
goarch := runtime.GOARCH
goarm := ""
@@ -512,7 +519,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
"Lazy connection: %s\n"+
"SSH Server: %s\n"+
"Networks: %s\n"+
"Forwarding rules: %d\n"+
"%s"+
"Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
o.DaemonVersion,
@@ -529,7 +536,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
lazyConnectionEnabledStatus,
sshServerStatus,
networks,
o.NumberOfForwardingRules,
forwardingRulesString,
peersCountString,
)
return summary
@@ -553,6 +560,94 @@ func (o *OutputOverview) FullDetailSummary() string {
)
}
func ToProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
SignalState: &proto.SignalState{},
LocalPeerState: &proto.LocalPeerState{},
Peers: []*proto.PeerState{},
}
pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL
pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected
if err := fullStatus.ManagementState.Error; err != nil {
pbFullStatus.ManagementState.Error = err.Error()
}
pbFullStatus.SignalState.URL = fullStatus.SignalState.URL
pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected
if err := fullStatus.SignalState.Error; err != nil {
pbFullStatus.SignalState.Error = err.Error()
}
pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP
pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey
pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface
pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN
pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
for _, peerState := range fullStatus.Peers {
pbPeerState := &proto.PeerState{
IP: peerState.IP,
PubKey: peerState.PubKey,
ConnStatus: peerState.ConnStatus.String(),
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
Relayed: peerState.Relayed,
LocalIceCandidateType: peerState.LocalIceCandidateType,
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
RelayAddress: peerState.RelayServerAddress,
Fqdn: peerState.FQDN,
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Networks: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency),
SshHostKey: peerState.SSHHostKey,
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}
for _, relayState := range fullStatus.Relays {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()
}
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
}
for _, dnsState := range fullStatus.NSGroupStates {
var err string
if dnsState.Error != nil {
err = dnsState.Error.Error()
}
var servers []string
for _, server := range dnsState.Servers {
servers = append(servers, server.String())
}
pbDnsState := &proto.NSGroupState{
Servers: servers,
Domains: dnsState.Domains,
Enabled: dnsState.Enabled,
Error: err,
}
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
}
return &pbFullStatus
}
func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
var (
peersString = ""

View File

@@ -238,7 +238,7 @@ var overview = OutputOverview{
}
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil, "", "")
convertedResult := ConvertToStatusOutputOverview(resp.GetFullStatus(), false, resp.GetDaemonVersion(), "", nil, nil, nil, "", "")
assert.Equal(t, overview, convertedResult)
}
@@ -567,7 +567,6 @@ Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
Networks: 10.10.0.0/24
Forwarding rules: 0
Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
@@ -592,7 +591,6 @@ Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
Networks: 10.10.0.0/24
Forwarding rules: 0
Peers count: 2/2 Connected
`

View File

@@ -1033,7 +1033,7 @@ func (s *serviceClient) onTrayReady() {
s.mDown.Disable()
systray.AddSeparator()
s.mSettings = systray.AddMenuItem("Settings", settingsMenuDescr)
s.mSettings = systray.AddMenuItem("Settings", disabledMenuDescr)
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", allowSSHMenuDescr, false)
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", autoConnectMenuDescr, false)
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", quantumResistanceMenuDescr, false)
@@ -1060,7 +1060,7 @@ func (s *serviceClient) onTrayReady() {
}
s.exitNodeMu.Lock()
s.mExitNode = systray.AddMenuItem("Exit Node", exitNodeMenuDescr)
s.mExitNode = systray.AddMenuItem("Exit Node", disabledMenuDescr)
s.mExitNode.Disable()
s.exitNodeMu.Unlock()
@@ -1261,7 +1261,6 @@ func (s *serviceClient) setSettingsEnabled(enabled bool) {
if s.mSettings != nil {
if enabled {
s.mSettings.Enable()
s.mSettings.SetTooltip(settingsMenuDescr)
} else {
s.mSettings.Hide()
s.mSettings.SetTooltip("Settings are disabled by daemon")

View File

@@ -1,8 +1,6 @@
package main
const (
settingsMenuDescr = "Settings of the application"
profilesMenuDescr = "Manage your profiles"
allowSSHMenuDescr = "Allow SSH connections"
autoConnectMenuDescr = "Connect automatically when the service starts"
quantumResistanceMenuDescr = "Enable post-quantum security via Rosenpass"
@@ -11,7 +9,7 @@ const (
notificationsMenuDescr = "Enable notifications"
advancedSettingsMenuDescr = "Advanced settings of the application"
debugBundleMenuDescr = "Create and open debug information bundle"
exitNodeMenuDescr = "Select exit node for routing traffic"
disabledMenuDescr = ""
networksMenuDescr = "Open the networks management window"
latestVersionMenuDescr = "Download latest version"
quitMenuDescr = "Quit the client app"

View File

@@ -18,9 +18,7 @@ import (
"github.com/skratchdot/open-golang/open"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
uptypes "github.com/netbirdio/netbird/upload-server/types"
)
@@ -291,19 +289,18 @@ func (s *serviceClient) handleRunForDuration(
return
}
statusOutput, err := s.collectDebugData(conn, initialState, params, progressUI)
if err != nil {
defer s.restoreServiceState(conn, initialState)
if err := s.collectDebugData(conn, initialState, params, progressUI); err != nil {
handleError(progressUI, err.Error())
return
}
if err := s.createDebugBundleFromCollection(conn, params, statusOutput, progressUI); err != nil {
if err := s.createDebugBundleFromCollection(conn, params, progressUI); err != nil {
handleError(progressUI, err.Error())
return
}
s.restoreServiceState(conn, initialState)
progressUI.statusLabel.SetText("Bundle created successfully")
}
@@ -409,6 +406,10 @@ func (s *serviceClient) configureServiceForDebug(
}
time.Sleep(time.Second * 3)
if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil {
log.Warnf("failed to start CPU profiling: %v", err)
}
return nil
}
@@ -417,68 +418,37 @@ func (s *serviceClient) collectDebugData(
state *debugInitialState,
params *debugCollectionParams,
progress *progressUI,
) (string, error) {
) error {
ctx, cancel := context.WithTimeout(s.ctx, params.duration)
defer cancel()
var wg sync.WaitGroup
startProgressTracker(ctx, &wg, params.duration, progress)
if err := s.configureServiceForDebug(conn, state, params.enablePersistence); err != nil {
return "", err
return err
}
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
log.Warnf("Failed to get post-up status: %v", err)
}
var postUpStatusOutput string
if postUpStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
postUpStatusOutput = overview.FullDetailSummary()
}
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, postUpStatusOutput)
wg.Wait()
progress.progressBar.Hide()
progress.statusLabel.SetText("Collecting debug data...")
preDownStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
log.Warnf("Failed to get pre-down status: %v", err)
if _, err := conn.StopCPUProfile(s.ctx, &proto.StopCPUProfileRequest{}); err != nil {
log.Warnf("failed to stop CPU profiling: %v", err)
}
var preDownStatusOutput string
if preDownStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
preDownStatusOutput = overview.FullDetailSummary()
}
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
time.Now().Format(time.RFC3339), params.duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, preDownStatusOutput)
return statusOutput, nil
return nil
}
// Create the debug bundle with collected data
func (s *serviceClient) createDebugBundleFromCollection(
conn proto.DaemonServiceClient,
params *debugCollectionParams,
statusOutput string,
progress *progressUI,
) error {
progress.statusLabel.SetText("Creating debug bundle with collected logs...")
request := &proto.DebugBundleRequest{
Anonymize: params.anonymize,
Status: statusOutput,
SystemInfo: params.systemInfo,
}
@@ -581,26 +551,8 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
return nil, fmt.Errorf("get client: %v", err)
}
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
log.Warnf("failed to get status for debug bundle: %v", err)
}
var statusOutput string
if statusResp != nil {
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
statusOutput = overview.FullDetailSummary()
}
request := &proto.DebugBundleRequest{
Anonymize: anonymize,
Status: statusOutput,
SystemInfo: systemInfo,
}

View File

@@ -99,6 +99,8 @@ func (h *eventHandler) handleConnectClick() {
func (h *eventHandler) handleDisconnectClick() {
h.client.mDown.Disable()
h.client.exitNodeStates = []exitNodeState{}
if h.client.connectCancel != nil {
log.Debugf("cancelling ongoing connect operation")
h.client.connectCancel()

View File

@@ -390,7 +390,7 @@ func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) {
if runtime.GOOS == "linux" || runtime.GOOS == "freebsd" {
s.mExitNode.Remove()
s.mExitNode = systray.AddMenuItem("Exit Node", exitNodeMenuDescr)
s.mExitNode = systray.AddMenuItem("Exit Node", disabledMenuDescr)
}
var showDeselectAll bool

View File

@@ -12,7 +12,6 @@ import (
"google.golang.org/protobuf/encoding/protojson"
netbird "github.com/netbirdio/netbird/client/embed"
"github.com/netbirdio/netbird/client/proto"
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/client/wasm/internal/http"
@@ -350,12 +349,8 @@ func getStatusOverview(client *netbird.Client) (nbstatus.OutputOverview, error)
}
pbFullStatus := fullStatus.ToProto()
statusResp := &proto.StatusResponse{
DaemonVersion: version.NetbirdVersion(),
FullStatus: pbFullStatus,
}
return nbstatus.ConvertToStatusOutputOverview(statusResp, false, "", nil, nil, nil, "", ""), nil
return nbstatus.ConvertToStatusOutputOverview(pbFullStatus, false, version.NetbirdVersion(), "", nil, nil, nil, "", ""), nil
}
// createStatusMethod creates the status method that returns JSON

4
go.mod
View File

@@ -68,8 +68,9 @@ require (
github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/oapi-codegen/runtime v1.1.2
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
@@ -141,6 +142,7 @@ require (
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/Microsoft/hcsshim v0.12.3 // indirect
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
github.com/awnumar/memcall v0.4.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect

12
go.sum
View File

@@ -35,12 +35,15 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/Microsoft/hcsshim v0.12.3 h1:LS9NXqXhMoqNCplK1ApmVSfB4UnVLRDWRapB6EIlxE0=
github.com/Microsoft/hcsshim v0.12.3/go.mod h1:Iyl1WVpZzr+UkzjekHZbV8o5Z9ZkxNGx6CtY2Qg/JVQ=
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJTa/axESA0889D3UlZbLo=
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I=
github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI=
github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ=
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk=
github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g=
github.com/awnumar/memcall v0.4.0/go.mod h1:8xOx1YbfyuCg3Fy6TO8DK0kZUua3V42/goA5Ru47E8w=
github.com/awnumar/memguard v0.23.0 h1:sJ3a1/SWlcuKIQ7MV+R9p0Pvo9CWsMbGZvcZQtmc68A=
@@ -87,6 +90,7 @@ github.com/beevik/etree v1.6.0 h1:u8Kwy8pp9D9XeITj2Z0XtA5qqZEmtJtuXZRQi+j03eE=
github.com/beevik/etree v1.6.0/go.mod h1:bh4zJxiIr62SOf9pRzN7UUYaEDa9HEKafK25+sLc0Gc=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
@@ -320,6 +324,7 @@ github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7X
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M=
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE=
github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8=
github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
@@ -401,8 +406,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 h1:V0zsYYMU5d2UN1m9zOLPEZCGWpnhtkYcxQVi9Rrx3bY=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f h1:CTBf0je/FpKr2lVSMZLak7m8aaWcS6ur4SOfhSSazFI=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
@@ -416,6 +421,8 @@ github.com/nicksnyder/go-i18n/v2 v2.5.1/go.mod h1:DrhgsSDZxoAfvVrBVLXoxZn/pN5TXq
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/oapi-codegen/runtime v1.1.2 h1:P2+CubHq8fO4Q6fV1tqDBZHCwpVpvPg7oKiYzQgXIyI=
github.com/oapi-codegen/runtime v1.1.2/go.mod h1:SK9X900oXmPWilYR5/WKPzt3Kqxn/uS/+lbpREv+eCg=
github.com/okta/okta-sdk-golang/v2 v2.18.0 h1:cfDasMb7CShbZvOrF6n+DnLevWwiHgedWMGJ8M8xKDc=
github.com/okta/okta-sdk-golang/v2 v2.18.0/go.mod h1:dz30v3ctAiMb7jpsCngGfQUAEGm1/NsWT92uTbNDQIs=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
@@ -522,6 +529,7 @@ github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s=
github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0=
github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c h1:km8GpoQut05eY3GiYWEedbTT0qnSxrCjsVbb7yKY1KE=
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c/go.mod h1:cNQ3dwVJtS5Hmnjxy6AgTPd0Inb3pW05ftPSX7NZO7Q=
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef h1:Ch6Q+AZUxDBCVqdkI8FSpFyZDtCVBc2VmejdNrm5rRQ=

356
idp/dex/connector.go Normal file
View File

@@ -0,0 +1,356 @@
// Package dex provides an embedded Dex OIDC identity provider.
package dex
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/dexidp/dex/storage"
)
// ConnectorConfig represents the configuration for an identity provider connector
type ConnectorConfig struct {
// ID is the unique identifier for the connector
ID string
// Name is a human-readable name for the connector
Name string
// Type is the connector type (oidc, google, microsoft)
Type string
// Issuer is the OIDC issuer URL (for OIDC-based connectors)
Issuer string
// ClientID is the OAuth2 client ID
ClientID string
// ClientSecret is the OAuth2 client secret
ClientSecret string
// RedirectURI is the OAuth2 redirect URI
RedirectURI string
}
// CreateConnector creates a new connector in Dex storage.
// It maps the connector config to the appropriate Dex connector type and configuration.
func (p *Provider) CreateConnector(ctx context.Context, cfg *ConnectorConfig) (*ConnectorConfig, error) {
// Fill in the redirect URI if not provided
if cfg.RedirectURI == "" {
cfg.RedirectURI = p.GetRedirectURI()
}
storageConn, err := p.buildStorageConnector(cfg)
if err != nil {
return nil, fmt.Errorf("failed to build connector: %w", err)
}
if err := p.storage.CreateConnector(ctx, storageConn); err != nil {
return nil, fmt.Errorf("failed to create connector: %w", err)
}
p.logger.Info("connector created", "id", cfg.ID, "type", cfg.Type)
return cfg, nil
}
// GetConnector retrieves a connector by ID from Dex storage.
func (p *Provider) GetConnector(ctx context.Context, id string) (*ConnectorConfig, error) {
conn, err := p.storage.GetConnector(ctx, id)
if err != nil {
if err == storage.ErrNotFound {
return nil, err
}
return nil, fmt.Errorf("failed to get connector: %w", err)
}
return p.parseStorageConnector(conn)
}
// ListConnectors returns all connectors from Dex storage (excluding the local connector).
func (p *Provider) ListConnectors(ctx context.Context) ([]*ConnectorConfig, error) {
connectors, err := p.storage.ListConnectors(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list connectors: %w", err)
}
result := make([]*ConnectorConfig, 0, len(connectors))
for _, conn := range connectors {
// Skip the local password connector
if conn.ID == "local" && conn.Type == "local" {
continue
}
cfg, err := p.parseStorageConnector(conn)
if err != nil {
p.logger.Warn("failed to parse connector", "id", conn.ID, "error", err)
continue
}
result = append(result, cfg)
}
return result, nil
}
// UpdateConnector updates an existing connector in Dex storage.
// It merges incoming updates with existing values to prevent data loss on partial updates.
func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error {
if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) {
oldCfg, err := p.parseStorageConnector(old)
if err != nil {
return storage.Connector{}, fmt.Errorf("failed to parse existing connector: %w", err)
}
mergeConnectorConfig(cfg, oldCfg)
storageConn, err := p.buildStorageConnector(cfg)
if err != nil {
return storage.Connector{}, fmt.Errorf("failed to build connector: %w", err)
}
return storageConn, nil
}); err != nil {
return fmt.Errorf("failed to update connector: %w", err)
}
p.logger.Info("connector updated", "id", cfg.ID, "type", cfg.Type)
return nil
}
// mergeConnectorConfig preserves existing values for empty fields in the update.
func mergeConnectorConfig(cfg, oldCfg *ConnectorConfig) {
if cfg.ClientSecret == "" {
cfg.ClientSecret = oldCfg.ClientSecret
}
if cfg.RedirectURI == "" {
cfg.RedirectURI = oldCfg.RedirectURI
}
if cfg.Issuer == "" && cfg.Type == oldCfg.Type {
cfg.Issuer = oldCfg.Issuer
}
if cfg.ClientID == "" {
cfg.ClientID = oldCfg.ClientID
}
if cfg.Name == "" {
cfg.Name = oldCfg.Name
}
}
// DeleteConnector removes a connector from Dex storage.
func (p *Provider) DeleteConnector(ctx context.Context, id string) error {
// Prevent deletion of the local connector
if id == "local" {
return fmt.Errorf("cannot delete the local password connector")
}
if err := p.storage.DeleteConnector(ctx, id); err != nil {
return fmt.Errorf("failed to delete connector: %w", err)
}
p.logger.Info("connector deleted", "id", id)
return nil
}
// GetRedirectURI returns the default redirect URI for connectors.
func (p *Provider) GetRedirectURI() string {
if p.config == nil {
return ""
}
issuer := strings.TrimSuffix(p.config.Issuer, "/")
if !strings.HasSuffix(issuer, "/oauth2") {
issuer += "/oauth2"
}
return issuer + "/callback"
}
// buildStorageConnector creates a storage.Connector from ConnectorConfig.
// It handles the type-specific configuration for each connector type.
func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connector, error) {
redirectURI := p.resolveRedirectURI(cfg.RedirectURI)
var dexType string
var configData []byte
var err error
switch cfg.Type {
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
dexType = "oidc"
configData, err = buildOIDCConnectorConfig(cfg, redirectURI)
case "google":
dexType = "google"
configData, err = buildOAuth2ConnectorConfig(cfg, redirectURI)
case "microsoft":
dexType = "microsoft"
configData, err = buildOAuth2ConnectorConfig(cfg, redirectURI)
default:
return storage.Connector{}, fmt.Errorf("unsupported connector type: %s", cfg.Type)
}
if err != nil {
return storage.Connector{}, err
}
return storage.Connector{ID: cfg.ID, Type: dexType, Name: cfg.Name, Config: configData}, nil
}
// resolveRedirectURI returns the redirect URI, using a default if not provided
func (p *Provider) resolveRedirectURI(redirectURI string) string {
if redirectURI != "" || p.config == nil {
return redirectURI
}
issuer := strings.TrimSuffix(p.config.Issuer, "/")
if !strings.HasSuffix(issuer, "/oauth2") {
issuer += "/oauth2"
}
return issuer + "/callback"
}
// buildOIDCConnectorConfig creates config for OIDC-based connectors
func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) {
oidcConfig := map[string]interface{}{
"issuer": cfg.Issuer,
"clientID": cfg.ClientID,
"clientSecret": cfg.ClientSecret,
"redirectURI": redirectURI,
"scopes": []string{"openid", "profile", "email"},
"insecureEnableGroups": true,
//some providers don't return email verified, so we need to skip it if not present (e.g., Entra, Okta, Duo)
"insecureSkipEmailVerified": true,
}
switch cfg.Type {
case "zitadel":
oidcConfig["getUserInfo"] = true
case "entra":
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
case "okta":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
case "pocketid":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
}
return encodeConnectorConfig(oidcConfig)
}
// buildOAuth2ConnectorConfig creates config for OAuth2 connectors (google, microsoft)
func buildOAuth2ConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) {
return encodeConnectorConfig(map[string]interface{}{
"clientID": cfg.ClientID,
"clientSecret": cfg.ClientSecret,
"redirectURI": redirectURI,
})
}
// parseStorageConnector converts a storage.Connector back to ConnectorConfig.
// It infers the original identity provider type from the Dex connector type and ID.
func (p *Provider) parseStorageConnector(conn storage.Connector) (*ConnectorConfig, error) {
cfg := &ConnectorConfig{
ID: conn.ID,
Name: conn.Name,
}
if len(conn.Config) == 0 {
cfg.Type = conn.Type
return cfg, nil
}
var configMap map[string]interface{}
if err := decodeConnectorConfig(conn.Config, &configMap); err != nil {
return nil, fmt.Errorf("failed to parse connector config: %w", err)
}
// Extract common fields
if v, ok := configMap["clientID"].(string); ok {
cfg.ClientID = v
}
if v, ok := configMap["clientSecret"].(string); ok {
cfg.ClientSecret = v
}
if v, ok := configMap["redirectURI"].(string); ok {
cfg.RedirectURI = v
}
if v, ok := configMap["issuer"].(string); ok {
cfg.Issuer = v
}
// Infer the original identity provider type from Dex connector type and ID
cfg.Type = inferIdentityProviderType(conn.Type, conn.ID, configMap)
return cfg, nil
}
// inferIdentityProviderType determines the original identity provider type
// based on the Dex connector type, connector ID, and configuration.
func inferIdentityProviderType(dexType, connectorID string, _ map[string]interface{}) string {
if dexType != "oidc" {
return dexType
}
return inferOIDCProviderType(connectorID)
}
// inferOIDCProviderType infers the specific OIDC provider from connector ID
func inferOIDCProviderType(connectorID string) string {
connectorIDLower := strings.ToLower(connectorID)
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"} {
if strings.Contains(connectorIDLower, provider) {
return provider
}
}
return "oidc"
}
// encodeConnectorConfig serializes connector config to JSON bytes.
func encodeConnectorConfig(config map[string]interface{}) ([]byte, error) {
return json.Marshal(config)
}
// decodeConnectorConfig deserializes connector config from JSON bytes.
func decodeConnectorConfig(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}
// ensureLocalConnector creates a local (password) connector if it doesn't exist
func ensureLocalConnector(ctx context.Context, stor storage.Storage) error {
// Check specifically for the local connector
_, err := stor.GetConnector(ctx, "local")
if err == nil {
// Local connector already exists
return nil
}
if !errors.Is(err, storage.ErrNotFound) {
return fmt.Errorf("failed to get local connector: %w", err)
}
// Create a local connector for password authentication
localConnector := storage.Connector{
ID: "local",
Type: "local",
Name: "Email",
}
if err := stor.CreateConnector(ctx, localConnector); err != nil {
return fmt.Errorf("failed to create local connector: %w", err)
}
return nil
}
// ensureStaticConnectors creates or updates static connectors in storage
func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error {
for _, conn := range connectors {
storConn, err := conn.ToStorageConnector()
if err != nil {
return fmt.Errorf("failed to convert connector %s: %w", conn.ID, err)
}
_, err = stor.GetConnector(ctx, conn.ID)
if err == storage.ErrNotFound {
if err := stor.CreateConnector(ctx, storConn); err != nil {
return fmt.Errorf("failed to create connector %s: %w", conn.ID, err)
}
continue
}
if err != nil {
return fmt.Errorf("failed to get connector %s: %w", conn.ID, err)
}
if err := stor.UpdateConnector(ctx, conn.ID, func(old storage.Connector) (storage.Connector, error) {
old.Name = storConn.Name
old.Config = storConn.Config
return old, nil
}); err != nil {
return fmt.Errorf("failed to update connector %s: %w", conn.ID, err)
}
}
return nil
}

View File

@@ -4,7 +4,6 @@ package dex
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log/slog"
@@ -245,34 +244,6 @@ func ensureStaticClients(ctx context.Context, stor storage.Storage, clients []st
return nil
}
// ensureStaticConnectors creates or updates static connectors in storage
func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error {
for _, conn := range connectors {
storConn, err := conn.ToStorageConnector()
if err != nil {
return fmt.Errorf("failed to convert connector %s: %w", conn.ID, err)
}
_, err = stor.GetConnector(ctx, conn.ID)
if errors.Is(err, storage.ErrNotFound) {
if err := stor.CreateConnector(ctx, storConn); err != nil {
return fmt.Errorf("failed to create connector %s: %w", conn.ID, err)
}
continue
}
if err != nil {
return fmt.Errorf("failed to get connector %s: %w", conn.ID, err)
}
if err := stor.UpdateConnector(ctx, conn.ID, func(old storage.Connector) (storage.Connector, error) {
old.Name = storConn.Name
old.Config = storConn.Config
return old, nil
}); err != nil {
return fmt.Errorf("failed to update connector %s: %w", conn.ID, err)
}
}
return nil
}
// buildDexConfig creates a server.Config with defaults applied
func buildDexConfig(yamlConfig *YAMLConfig, stor storage.Storage, logger *slog.Logger) server.Config {
cfg := yamlConfig.ToServerConfig(stor, logger)
@@ -613,294 +584,37 @@ func (p *Provider) ListUsers(ctx context.Context) ([]storage.Password, error) {
return p.storage.ListPasswords(ctx)
}
// ensureLocalConnector creates a local (password) connector if none exists
func ensureLocalConnector(ctx context.Context, stor storage.Storage) error {
connectors, err := stor.ListConnectors(ctx)
// UpdateUserPassword updates the password for a user identified by userID.
// The userID can be either an encoded Dex ID (base64 protobuf) or a raw UUID.
// It verifies the current password before updating.
func (p *Provider) UpdateUserPassword(ctx context.Context, userID string, oldPassword, newPassword string) error {
// Get the user by ID to find their email
user, err := p.GetUserByID(ctx, userID)
if err != nil {
return fmt.Errorf("failed to list connectors: %w", err)
return fmt.Errorf("failed to get user: %w", err)
}
// If any connector exists, we're good
if len(connectors) > 0 {
return nil
// Verify old password
if err := bcrypt.CompareHashAndPassword(user.Hash, []byte(oldPassword)); err != nil {
return fmt.Errorf("current password is incorrect")
}
// Create a local connector for password authentication
localConnector := storage.Connector{
ID: "local",
Type: "local",
Name: "Email",
}
if err := stor.CreateConnector(ctx, localConnector); err != nil {
return fmt.Errorf("failed to create local connector: %w", err)
}
return nil
}
// ConnectorConfig represents the configuration for an identity provider connector
type ConnectorConfig struct {
// ID is the unique identifier for the connector
ID string
// Name is a human-readable name for the connector
Name string
// Type is the connector type (oidc, google, microsoft)
Type string
// Issuer is the OIDC issuer URL (for OIDC-based connectors)
Issuer string
// ClientID is the OAuth2 client ID
ClientID string
// ClientSecret is the OAuth2 client secret
ClientSecret string
// RedirectURI is the OAuth2 redirect URI
RedirectURI string
}
// CreateConnector creates a new connector in Dex storage.
// It maps the connector config to the appropriate Dex connector type and configuration.
func (p *Provider) CreateConnector(ctx context.Context, cfg *ConnectorConfig) (*ConnectorConfig, error) {
// Fill in the redirect URI if not provided
if cfg.RedirectURI == "" {
cfg.RedirectURI = p.GetRedirectURI()
}
storageConn, err := p.buildStorageConnector(cfg)
// Hash the new password
newHash, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("failed to build connector: %w", err)
return fmt.Errorf("failed to hash new password: %w", err)
}
if err := p.storage.CreateConnector(ctx, storageConn); err != nil {
return nil, fmt.Errorf("failed to create connector: %w", err)
}
p.logger.Info("connector created", "id", cfg.ID, "type", cfg.Type)
return cfg, nil
}
// GetConnector retrieves a connector by ID from Dex storage.
func (p *Provider) GetConnector(ctx context.Context, id string) (*ConnectorConfig, error) {
conn, err := p.storage.GetConnector(ctx, id)
if err != nil {
if err == storage.ErrNotFound {
return nil, err
}
return nil, fmt.Errorf("failed to get connector: %w", err)
}
return p.parseStorageConnector(conn)
}
// ListConnectors returns all connectors from Dex storage (excluding the local connector).
func (p *Provider) ListConnectors(ctx context.Context) ([]*ConnectorConfig, error) {
connectors, err := p.storage.ListConnectors(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list connectors: %w", err)
}
result := make([]*ConnectorConfig, 0, len(connectors))
for _, conn := range connectors {
// Skip the local password connector
if conn.ID == "local" && conn.Type == "local" {
continue
}
cfg, err := p.parseStorageConnector(conn)
if err != nil {
p.logger.Warn("failed to parse connector", "id", conn.ID, "error", err)
continue
}
result = append(result, cfg)
}
return result, nil
}
// UpdateConnector updates an existing connector in Dex storage.
func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error {
storageConn, err := p.buildStorageConnector(cfg)
if err != nil {
return fmt.Errorf("failed to build connector: %w", err)
}
if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) {
return storageConn, nil
}); err != nil {
return fmt.Errorf("failed to update connector: %w", err)
}
p.logger.Info("connector updated", "id", cfg.ID, "type", cfg.Type)
return nil
}
// DeleteConnector removes a connector from Dex storage.
func (p *Provider) DeleteConnector(ctx context.Context, id string) error {
// Prevent deletion of the local connector
if id == "local" {
return fmt.Errorf("cannot delete the local password connector")
}
if err := p.storage.DeleteConnector(ctx, id); err != nil {
return fmt.Errorf("failed to delete connector: %w", err)
}
p.logger.Info("connector deleted", "id", id)
return nil
}
// buildStorageConnector creates a storage.Connector from ConnectorConfig.
// It handles the type-specific configuration for each connector type.
func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connector, error) {
redirectURI := p.resolveRedirectURI(cfg.RedirectURI)
var dexType string
var configData []byte
var err error
switch cfg.Type {
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
dexType = "oidc"
configData, err = buildOIDCConnectorConfig(cfg, redirectURI)
case "google":
dexType = "google"
configData, err = buildOAuth2ConnectorConfig(cfg, redirectURI)
case "microsoft":
dexType = "microsoft"
configData, err = buildOAuth2ConnectorConfig(cfg, redirectURI)
default:
return storage.Connector{}, fmt.Errorf("unsupported connector type: %s", cfg.Type)
}
if err != nil {
return storage.Connector{}, err
}
return storage.Connector{ID: cfg.ID, Type: dexType, Name: cfg.Name, Config: configData}, nil
}
// resolveRedirectURI returns the redirect URI, using a default if not provided
func (p *Provider) resolveRedirectURI(redirectURI string) string {
if redirectURI != "" || p.config == nil {
return redirectURI
}
issuer := strings.TrimSuffix(p.config.Issuer, "/")
if !strings.HasSuffix(issuer, "/oauth2") {
issuer += "/oauth2"
}
return issuer + "/callback"
}
// buildOIDCConnectorConfig creates config for OIDC-based connectors
func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) {
oidcConfig := map[string]interface{}{
"issuer": cfg.Issuer,
"clientID": cfg.ClientID,
"clientSecret": cfg.ClientSecret,
"redirectURI": redirectURI,
"scopes": []string{"openid", "profile", "email"},
"insecureEnableGroups": true,
//some providers don't return email verified, so we need to skip it if not present (e.g., Entra, Okta, Duo)
"insecureSkipEmailVerified": true,
}
switch cfg.Type {
case "zitadel":
oidcConfig["getUserInfo"] = true
case "entra":
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
case "okta":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
case "pocketid":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
}
return encodeConnectorConfig(oidcConfig)
}
// buildOAuth2ConnectorConfig creates config for OAuth2 connectors (google, microsoft)
func buildOAuth2ConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) {
return encodeConnectorConfig(map[string]interface{}{
"clientID": cfg.ClientID,
"clientSecret": cfg.ClientSecret,
"redirectURI": redirectURI,
// Update the password in storage
err = p.storage.UpdatePassword(ctx, user.Email, func(old storage.Password) (storage.Password, error) {
old.Hash = newHash
return old, nil
})
}
// parseStorageConnector converts a storage.Connector back to ConnectorConfig.
// It infers the original identity provider type from the Dex connector type and ID.
func (p *Provider) parseStorageConnector(conn storage.Connector) (*ConnectorConfig, error) {
cfg := &ConnectorConfig{
ID: conn.ID,
Name: conn.Name,
if err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
if len(conn.Config) == 0 {
cfg.Type = conn.Type
return cfg, nil
}
var configMap map[string]interface{}
if err := decodeConnectorConfig(conn.Config, &configMap); err != nil {
return nil, fmt.Errorf("failed to parse connector config: %w", err)
}
// Extract common fields
if v, ok := configMap["clientID"].(string); ok {
cfg.ClientID = v
}
if v, ok := configMap["clientSecret"].(string); ok {
cfg.ClientSecret = v
}
if v, ok := configMap["redirectURI"].(string); ok {
cfg.RedirectURI = v
}
if v, ok := configMap["issuer"].(string); ok {
cfg.Issuer = v
}
// Infer the original identity provider type from Dex connector type and ID
cfg.Type = inferIdentityProviderType(conn.Type, conn.ID, configMap)
return cfg, nil
}
// inferIdentityProviderType determines the original identity provider type
// based on the Dex connector type, connector ID, and configuration.
func inferIdentityProviderType(dexType, connectorID string, _ map[string]interface{}) string {
if dexType != "oidc" {
return dexType
}
return inferOIDCProviderType(connectorID)
}
// inferOIDCProviderType infers the specific OIDC provider from connector ID
func inferOIDCProviderType(connectorID string) string {
connectorIDLower := strings.ToLower(connectorID)
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"} {
if strings.Contains(connectorIDLower, provider) {
return provider
}
}
return "oidc"
}
// encodeConnectorConfig serializes connector config to JSON bytes.
func encodeConnectorConfig(config map[string]interface{}) ([]byte, error) {
return json.Marshal(config)
}
// decodeConnectorConfig deserializes connector config from JSON bytes.
func decodeConnectorConfig(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}
// GetRedirectURI returns the default redirect URI for connectors.
func (p *Provider) GetRedirectURI() string {
if p.config == nil {
return ""
}
issuer := strings.TrimSuffix(p.config.Issuer, "/")
if !strings.HasSuffix(issuer, "/oauth2") {
issuer += "/oauth2"
}
return issuer + "/callback"
return nil
}
// GetIssuer returns the OIDC issuer URL.

View File

@@ -82,16 +82,6 @@ read_nb_domain() {
return 0
}
get_turn_external_ip() {
TURN_EXTERNAL_IP_CONFIG="#external-ip="
IP=$(curl -s -4 https://jsonip.com | jq -r '.ip')
if [[ "x-$IP" != "x-" ]]; then
TURN_EXTERNAL_IP_CONFIG="external-ip=$IP"
fi
echo "$TURN_EXTERNAL_IP_CONFIG"
return 0
}
read_reverse_proxy_type() {
echo "" > /dev/stderr
echo "Which reverse proxy will you use?" > /dev/stderr
@@ -249,14 +239,17 @@ initialize_default_values() {
NETBIRD_PORT=80
NETBIRD_HTTP_PROTOCOL="http"
NETBIRD_RELAY_PROTO="rel"
TURN_USER="self"
TURN_PASSWORD=$(openssl rand -base64 32 | sed "$SED_STRIP_PADDING")
NETBIRD_RELAY_AUTH_SECRET=$(openssl rand -base64 32 | sed "$SED_STRIP_PADDING")
# Note: DataStoreEncryptionKey must keep base64 padding (=) for Go's base64.StdEncoding
DATASTORE_ENCRYPTION_KEY=$(openssl rand -base64 32)
TURN_MIN_PORT=49152
TURN_MAX_PORT=65535
TURN_EXTERNAL_IP_CONFIG=$(get_turn_external_ip)
NETBIRD_STUN_PORT=3478
# Docker images
CADDY_IMAGE="caddy"
DASHBOARD_IMAGE="netbirdio/dashboard:latest"
SIGNAL_IMAGE="netbirdio/signal:latest"
RELAY_IMAGE="netbirdio/relay:latest"
MANAGEMENT_IMAGE="netbirdio/management:latest"
# Reverse proxy configuration
REVERSE_PROXY_TYPE="0"
@@ -320,7 +313,7 @@ check_existing_installation() {
echo "Generated files already exist, if you want to reinitialize the environment, please remove them first."
echo "You can use the following commands:"
echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes"
echo " rm -f docker-compose.yml Caddyfile dashboard.env turnserver.conf management.json relay.env nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt"
echo " rm -f docker-compose.yml Caddyfile dashboard.env management.json relay.env nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt"
echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard."
exit 1
fi
@@ -363,7 +356,6 @@ generate_configuration_files() {
# Common files for all configurations
render_dashboard_env > dashboard.env
render_management_json > management.json
render_turn_server_conf > turnserver.conf
render_relay_env > relay.env
return 0
}
@@ -487,34 +479,13 @@ EOF
return 0
}
render_turn_server_conf() {
cat <<EOF
listening-port=3478
$TURN_EXTERNAL_IP_CONFIG
tls-listening-port=5349
min-port=$TURN_MIN_PORT
max-port=$TURN_MAX_PORT
fingerprint
lt-cred-mech
user=$TURN_USER:$TURN_PASSWORD
realm=wiretrustee.com
cert=/etc/coturn/certs/cert.pem
pkey=/etc/coturn/private/privkey.pem
log-file=stdout
no-software-attribute
pidfile="/var/tmp/turnserver.pid"
no-cli
EOF
return 0
}
render_management_json() {
cat <<EOF
{
"Stuns": [
{
"Proto": "udp",
"URI": "stun:$NETBIRD_DOMAIN:3478"
"URI": "stun:$NETBIRD_DOMAIN:$NETBIRD_STUN_PORT"
}
],
"Relay": {
@@ -569,6 +540,9 @@ NB_LOG_LEVEL=info
NB_LISTEN_ADDRESS=:80
NB_EXPOSED_ADDRESS=$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT
NB_AUTH_SECRET=$NETBIRD_RELAY_AUTH_SECRET
NB_ENABLE_STUN=true
NB_STUN_LOG_LEVEL=info
NB_STUN_PORTS=$NETBIRD_STUN_PORT
EOF
return 0
}
@@ -578,7 +552,7 @@ render_docker_compose() {
services:
# Caddy reverse proxy
caddy:
image: caddy
image: $CADDY_IMAGE
container_name: netbird-caddy
restart: unless-stopped
networks: [netbird]
@@ -597,7 +571,7 @@ services:
# UI dashboard
dashboard:
image: netbirdio/dashboard:latest
image: $DASHBOARD_IMAGE
container_name: netbird-dashboard
restart: unless-stopped
networks: [netbird]
@@ -611,7 +585,7 @@ services:
# Signal
signal:
image: netbirdio/signal:latest
image: $SIGNAL_IMAGE
container_name: netbird-signal
restart: unless-stopped
networks: [netbird]
@@ -621,12 +595,14 @@ services:
max-size: "500m"
max-file: "2"
# Relay
# Relay (includes embedded STUN server)
relay:
image: netbirdio/relay:latest
image: $RELAY_IMAGE
container_name: netbird-relay
restart: unless-stopped
networks: [netbird]
ports:
- '$NETBIRD_STUN_PORT:$NETBIRD_STUN_PORT/udp'
env_file:
- ./relay.env
logging:
@@ -637,7 +613,7 @@ services:
# Management (includes embedded IdP)
management:
image: netbirdio/management:latest
image: $MANAGEMENT_IMAGE
container_name: netbird-management
restart: unless-stopped
networks: [netbird]
@@ -659,22 +635,6 @@ services:
max-size: "500m"
max-file: "2"
# Coturn, AKA TURN server
coturn:
image: coturn/coturn
container_name: netbird-coturn
restart: unless-stopped
volumes:
- ./turnserver.conf:/etc/turnserver.conf:ro
network_mode: host
command:
- -c /etc/turnserver.conf
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
volumes:
netbird_caddy_data:
netbird_management:
@@ -702,7 +662,7 @@ render_docker_compose_traefik() {
services:
# UI dashboard
dashboard:
image: netbirdio/dashboard:latest
image: $DASHBOARD_IMAGE
container_name: netbird-dashboard
restart: unless-stopped
networks: [$network_name]
@@ -724,7 +684,7 @@ $(if [[ -n "$tls_labels" ]]; then echo " - traefik.http.routers.netbird-das
# Signal
signal:
image: netbirdio/signal:latest
image: $SIGNAL_IMAGE
container_name: netbird-signal
restart: unless-stopped
networks: [$network_name]
@@ -751,12 +711,14 @@ $(if [[ -n "$tls_labels" ]]; then echo " - traefik.http.routers.netbird-sig
max-size: "500m"
max-file: "2"
# Relay
# Relay (includes embedded STUN server)
relay:
image: netbirdio/relay:latest
image: $RELAY_IMAGE
container_name: netbird-relay
restart: unless-stopped
networks: [$network_name]
ports:
- '$NETBIRD_STUN_PORT:$NETBIRD_STUN_PORT/udp'
env_file:
- ./relay.env
labels:
@@ -774,7 +736,7 @@ $(if [[ -n "$tls_labels" ]]; then echo " - traefik.http.routers.netbird-rel
# Management (includes embedded IdP)
management:
image: netbirdio/management:latest
image: $MANAGEMENT_IMAGE
container_name: netbird-management
restart: unless-stopped
networks: [$network_name]
@@ -827,24 +789,6 @@ $(if [[ -n "$tls_labels" ]]; then echo " - traefik.http.routers.netbird-oau
max-size: "500m"
max-file: "2"
# Coturn, AKA TURN server
coturn:
image: coturn/coturn
container_name: netbird-coturn
restart: unless-stopped
volumes:
- ./turnserver.conf:/etc/turnserver.conf:ro
network_mode: host
labels:
- traefik.enable=false
command:
- -c /etc/turnserver.conf
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
volumes:
netbird_management:
@@ -874,7 +818,7 @@ render_docker_compose_exposed_ports() {
services:
# UI dashboard
dashboard:
image: netbirdio/dashboard:latest
image: $DASHBOARD_IMAGE
container_name: netbird-dashboard
restart: unless-stopped
networks: ${networks}
@@ -890,7 +834,7 @@ services:
# Signal
signal:
image: netbirdio/signal:latest
image: $SIGNAL_IMAGE
container_name: netbird-signal
restart: unless-stopped
networks: ${networks}
@@ -903,14 +847,15 @@ services:
max-size: "500m"
max-file: "2"
# Relay
# Relay (includes embedded STUN server)
relay:
image: netbirdio/relay:latest
image: $RELAY_IMAGE
container_name: netbird-relay
restart: unless-stopped
networks: ${networks}
ports:
- '${bind_addr}:${RELAY_HOST_PORT}:80'
- '$NETBIRD_STUN_PORT:$NETBIRD_STUN_PORT/udp'
env_file:
- ./relay.env
logging:
@@ -921,7 +866,7 @@ services:
# Management (includes embedded IdP)
management:
image: netbirdio/management:latest
image: $MANAGEMENT_IMAGE
container_name: netbird-management
restart: unless-stopped
networks: ${networks}
@@ -945,22 +890,6 @@ services:
max-size: "500m"
max-file: "2"
# Coturn, AKA TURN server
coturn:
image: coturn/coturn
container_name: netbird-coturn
restart: unless-stopped
volumes:
- ./turnserver.conf:/etc/turnserver.conf:ro
network_mode: host
command:
- -c /etc/turnserver.conf
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
volumes:
netbird_management:

View File

@@ -856,3 +856,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
func (c *Controller) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) {
c.peersUpdateManager.CloseChannels(ctx, peerIDs)
}
func (c *Controller) TrackEphemeralPeer(ctx context.Context, peer *nbpeer.Peer) {
c.EphemeralPeersManager.OnPeerDisconnected(ctx, peer)
}

View File

@@ -36,4 +36,6 @@ type Controller interface {
DisconnectPeers(ctx context.Context, accountId string, peerIDs []string)
OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerID string)
TrackEphemeralPeer(ctx context.Context, peer *nbpeer.Peer)
}

View File

@@ -1,9 +1,9 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ./interface.go
// Source: management/internals/controllers/network_map/interface.go
//
// Generated by this command:
//
// mockgen -package network_map -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
// mockgen -package network_map -destination=management/internals/controllers/network_map/interface_mock.go -source=management/internals/controllers/network_map/interface.go -build_flags=-mod=mod
//
// Package network_map is a generated GoMock package.
@@ -211,6 +211,18 @@ func (mr *MockControllerMockRecorder) StartWarmup(arg0 any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartWarmup", reflect.TypeOf((*MockController)(nil).StartWarmup), arg0)
}
// TrackEphemeralPeer mocks base method.
func (m *MockController) TrackEphemeralPeer(ctx context.Context, arg1 *peer.Peer) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "TrackEphemeralPeer", ctx, arg1)
}
// TrackEphemeralPeer indicates an expected call of TrackEphemeralPeer.
func (mr *MockControllerMockRecorder) TrackEphemeralPeer(ctx, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TrackEphemeralPeer", reflect.TypeOf((*MockController)(nil).TrackEphemeralPeer), ctx, arg1)
}
// UpdateAccountPeer mocks base method.
func (m *MockController) UpdateAccountPeer(ctx context.Context, accountId, peerId string) error {
m.ctrl.T.Helper()

View File

@@ -31,6 +31,7 @@ type Manager interface {
SetNetworkMapController(networkMapController network_map.Controller)
SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator)
SetAccountManager(accountManager account.Manager)
GetPeerID(ctx context.Context, peerKey string) (string, error)
}
type managerImpl struct {
@@ -167,3 +168,7 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
return nil
}
func (m *managerImpl) GetPeerID(ctx context.Context, peerKey string) (string, error) {
return m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
}

View File

@@ -97,6 +97,21 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID)
}
// GetPeerID mocks base method.
func (m *MockManager) GetPeerID(ctx context.Context, peerKey string) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeerID", ctx, peerKey)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPeerID indicates an expected call of GetPeerID.
func (mr *MockManagerMockRecorder) GetPeerID(ctx, peerKey interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerID", reflect.TypeOf((*MockManager)(nil).GetPeerID), ctx, peerKey)
}
// GetPeersByGroupIDs mocks base method.
func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
m.ctrl.T.Helper()

View File

@@ -144,7 +144,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
}
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider())
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider())
if err != nil {
log.Fatalf("failed to create management server: %v", err)
}

View File

@@ -6,6 +6,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
@@ -16,6 +17,7 @@ import (
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
)
func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
@@ -24,6 +26,12 @@ func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
})
}
func (s *BaseServer) JobManager() *job.Manager {
return Create(s, func() *job.Manager {
return job.NewJobManager(s.Metrics(), s.Store(), s.PeersManager())
})
}
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
return Create(s, func() integrated_validator.IntegratedValidator {
integratedPeerValidator, err := integrations.NewIntegratedValidator(

View File

@@ -87,7 +87,7 @@ func (s *BaseServer) PeersManager() peers.Manager {
func (s *BaseServer) AccountManager() account.Manager {
return Create(s, func() account.Manager {
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
if err != nil {
log.Fatalf("failed to create account manager: %v", err)
}

View File

@@ -195,6 +195,7 @@ func TestBuildJWTConfig_Audiences(t *testing.T) {
assert.NotNil(t, result)
assert.Equal(t, tc.expectedAudiences, result.Audiences, "audiences should match expected")
//nolint:staticcheck // SA1019: Testing backwards compatibility - Audience field must still be populated
assert.Equal(t, tc.expectedAudience, result.Audience, "audience should match expected")
})
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"net/netip"
"os"
@@ -26,6 +27,7 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/store"
@@ -57,6 +59,7 @@ type Server struct {
accountManager account.Manager
settingsManager settings.Manager
proto.UnimplementedManagementServiceServer
jobManager *job.Manager
config *nbconfig.Config
secretsManager SecretsManager
appMetrics telemetry.AppMetrics
@@ -82,6 +85,7 @@ func NewServer(
config *nbconfig.Config,
accountManager account.Manager,
settingsManager settings.Manager,
jobManager *job.Manager,
secretsManager SecretsManager,
appMetrics telemetry.AppMetrics,
authManager auth.Manager,
@@ -114,6 +118,7 @@ func NewServer(
}
return &Server{
jobManager: jobManager,
accountManager: accountManager,
settingsManager: settingsManager,
config: config,
@@ -169,6 +174,40 @@ func getRealIP(ctx context.Context) net.IP {
return nil
}
func (s *Server) Job(srv proto.ManagementService_JobServer) error {
reqStart := time.Now()
ctx := srv.Context()
peerKey, err := s.handleHandshake(ctx, srv)
if err != nil {
return err
}
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN")
log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String())
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
return status.Errorf(codes.PermissionDenied, "peer is not registered")
}
return err
}
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerKey.String())
if err != nil {
return status.Errorf(codes.Unauthenticated, "peer is not registered")
}
s.startResponseReceiver(ctx, srv)
updates := s.jobManager.CreateJobChannel(ctx, accountID, peer.ID)
log.WithContext(ctx).Debugf("Job: took %v", time.Since(reqStart))
return s.sendJobsLoop(ctx, accountID, peerKey, peer, updates, srv)
}
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
@@ -289,6 +328,70 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
}
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
hello, err := srv.Recv()
if err != nil {
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "missing hello: %v", err)
}
jobReq := &proto.JobRequest{}
peerKey, err := s.parseRequest(ctx, hello, jobReq)
if err != nil {
return wgtypes.Key{}, err
}
return peerKey, nil
}
func (s *Server) startResponseReceiver(ctx context.Context, srv proto.ManagementService_JobServer) {
go func() {
for {
msg, err := srv.Recv()
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
return
}
log.WithContext(ctx).Warnf("recv job response error: %v", err)
return
}
jobResp := &proto.JobResponse{}
if _, err := s.parseRequest(ctx, msg, jobResp); err != nil {
log.WithContext(ctx).Warnf("invalid job response: %v", err)
continue
}
if err := s.jobManager.HandleResponse(ctx, jobResp, msg.WgPubKey); err != nil {
log.WithContext(ctx).Errorf("handle job response failed: %v", err)
}
}
}()
}
func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates *job.Channel, srv proto.ManagementService_JobServer) error {
// todo figure out better error handling strategy
defer s.jobManager.CloseChannel(ctx, accountID, peer.ID)
for {
event, err := updates.Event(ctx)
if err != nil {
if errors.Is(err, job.ErrJobChannelClosed) {
log.WithContext(ctx).Debugf("jobs channel for peer %s was closed", peerKey.String())
return nil
}
// happens when connection drops, e.g. client disconnects
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
return ctx.Err()
}
if err := s.sendJob(ctx, peerKey, event, srv); err != nil {
log.WithContext(ctx).Warnf("send job failed: %v", err)
return nil
}
}
}
// handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
@@ -306,7 +409,6 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
return nil
}
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
return err
@@ -336,7 +438,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed processing update message")
}
err = srv.SendMsg(&proto.EncryptedMessage{
err = srv.Send(&proto.EncryptedMessage{
WgPubKey: key.PublicKey().String(),
Body: encryptedResp,
})
@@ -348,6 +450,31 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
return nil
}
// sendJob encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Event, srv proto.ManagementService_JobServer) error {
wgKey, err := s.secretsManager.GetWGKey()
if err != nil {
log.WithContext(ctx).Errorf("failed to get wg key for peer %s: %v", peerKey.String(), err)
return status.Errorf(codes.Internal, "failed processing job message")
}
encryptedResp, err := encryption.EncryptMessage(peerKey, wgKey, job.Request)
if err != nil {
log.WithContext(ctx).Errorf("failed to encrypt job for peer %s: %v", peerKey.String(), err)
return status.Errorf(codes.Internal, "failed processing job message")
}
err = srv.Send(&proto.EncryptedMessage{
WgPubKey: wgKey.PublicKey().String(),
Body: encryptedResp,
})
if err != nil {
return status.Errorf(codes.Internal, "failed sending job message")
}
log.WithContext(ctx).Debugf("sent a job to peer: %s", peerKey.String())
return nil
}
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
defer unlock()
@@ -690,8 +817,8 @@ func (s *Server) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty,
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer, dnsFwdPort int64) error {
var err error
var turnToken *Token
if s.config.TURNConfig != nil && s.config.TURNConfig.TimeBasedCredentials {
turnToken, err = s.secretsManager.GenerateTurnToken()
if err != nil {

View File

@@ -15,6 +15,7 @@ import (
"sync"
"time"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/shared/auth"
cacheStore "github.com/eko/gocache/lib/v4/store"
@@ -70,6 +71,7 @@ type DefaultAccountManager struct {
// cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded
cacheLoading map[string]chan struct{}
networkMapController network_map.Controller
jobManager *job.Manager
idpManager idp.Manager
cacheManager *nbcache.AccountUserDataCache
externalCacheManager nbcache.UserDataCache
@@ -178,6 +180,7 @@ func BuildManager(
config *nbconfig.Config,
store store.Store,
networkMapController network_map.Controller,
jobManager *job.Manager,
idpManager idp.Manager,
singleAccountModeDomain string,
eventStore activity.Store,
@@ -200,6 +203,7 @@ func BuildManager(
config: config,
geo: geo,
networkMapController: networkMapController,
jobManager: jobManager,
idpManager: idpManager,
ctx: context.Background(),
cacheMux: sync.Mutex{},

View File

@@ -32,6 +32,7 @@ type Manager interface {
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error)
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
UpdateUserPassword(ctx context.Context, accountID, currentUserID, targetUserID string, oldPassword, newPassword string) error
InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error)
RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error
@@ -129,4 +130,7 @@ type Manager interface {
CreateIdentityProvider(ctx context.Context, accountID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error)
UpdateIdentityProvider(ctx context.Context, accountID, idpID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error)
DeleteIdentityProvider(ctx context.Context, accountID, idpID, userID string) error
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
}

View File

@@ -35,6 +35,7 @@ import (
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -3023,13 +3024,14 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
AnyTimes()
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, nil, err
}

View File

@@ -195,6 +195,10 @@ const (
DNSRecordUpdated Activity = 100
DNSRecordDeleted Activity = 101
JobCreatedByUser Activity = 102
UserPasswordChanged Activity = 103
AccountDeleted Activity = 99999
)
@@ -319,6 +323,10 @@ var activityMap = map[Activity]Code{
DNSRecordCreated: {"DNS zone record created", "dns.zone.record.create"},
DNSRecordUpdated: {"DNS zone record updated", "dns.zone.record.update"},
DNSRecordDeleted: {"DNS zone record deleted", "dns.zone.record.delete"},
JobCreatedByUser: {"Create Job for peer", "peer.job.create"},
UserPasswordChanged: {"User password changed", "user.password.change"},
}
// StringCode returns a string code of the activity

View File

@@ -1,136 +0,0 @@
package store
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"errors"
)
var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
type FieldEncrypt struct {
block cipher.Block
gcm cipher.AEAD
}
func GenerateKey() (string, error) {
key := make([]byte, 32)
_, err := rand.Read(key)
if err != nil {
return "", err
}
readableKey := base64.StdEncoding.EncodeToString(key)
return readableKey, nil
}
func NewFieldEncrypt(key string) (*FieldEncrypt, error) {
binKey, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return nil, err
}
block, err := aes.NewCipher(binKey)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
ec := &FieldEncrypt{
block: block,
gcm: gcm,
}
return ec, nil
}
func (ec *FieldEncrypt) LegacyEncrypt(payload string) string {
plainText := pkcs5Padding([]byte(payload))
cipherText := make([]byte, len(plainText))
cbc := cipher.NewCBCEncrypter(ec.block, iv)
cbc.CryptBlocks(cipherText, plainText)
return base64.StdEncoding.EncodeToString(cipherText)
}
// Encrypt encrypts plaintext using AES-GCM
func (ec *FieldEncrypt) Encrypt(payload string) (string, error) {
plaintext := []byte(payload)
nonceSize := ec.gcm.NonceSize()
nonce := make([]byte, nonceSize, len(plaintext)+nonceSize+ec.gcm.Overhead())
if _, err := rand.Read(nonce); err != nil {
return "", err
}
ciphertext := ec.gcm.Seal(nonce, nonce, plaintext, nil)
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
func (ec *FieldEncrypt) LegacyDecrypt(data string) (string, error) {
cipherText, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return "", err
}
cbc := cipher.NewCBCDecrypter(ec.block, iv)
cbc.CryptBlocks(cipherText, cipherText)
payload, err := pkcs5UnPadding(cipherText)
if err != nil {
return "", err
}
return string(payload), nil
}
// Decrypt decrypts ciphertext using AES-GCM
func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
cipherText, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return "", err
}
nonceSize := ec.gcm.NonceSize()
if len(cipherText) < nonceSize {
return "", errors.New("cipher text too short")
}
nonce, cipherText := cipherText[:nonceSize], cipherText[nonceSize:]
plainText, err := ec.gcm.Open(nil, nonce, cipherText, nil)
if err != nil {
return "", err
}
return string(plainText), nil
}
func pkcs5Padding(ciphertext []byte) []byte {
padding := aes.BlockSize - len(ciphertext)%aes.BlockSize
padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padText...)
}
func pkcs5UnPadding(src []byte) ([]byte, error) {
srcLen := len(src)
if srcLen == 0 {
return nil, errors.New("input data is empty")
}
paddingLen := int(src[srcLen-1])
if paddingLen == 0 || paddingLen > aes.BlockSize || paddingLen > srcLen {
return nil, errors.New("invalid padding size")
}
// Verify that all padding bytes are the same
for i := 0; i < paddingLen; i++ {
if src[srcLen-1-i] != byte(paddingLen) {
return nil, errors.New("invalid padding")
}
}
return src[:srcLen-paddingLen], nil
}

View File

@@ -1,310 +0,0 @@
package store
import (
"bytes"
"testing"
)
func TestGenerateKey(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err := NewFieldEncrypt(key)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted, err := ee.Encrypt(testData)
if err != nil {
t.Fatalf("failed to encrypt data: %s", err)
}
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
decrypted, err := ee.Decrypt(encrypted)
if err != nil {
t.Fatalf("failed to decrypt data: %s", err)
}
if decrypted != testData {
t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
}
}
func TestGenerateKeyLegacy(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err := NewFieldEncrypt(key)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted := ee.LegacyEncrypt(testData)
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
decrypted, err := ee.LegacyDecrypt(encrypted)
if err != nil {
t.Fatalf("failed to decrypt data: %s", err)
}
if decrypted != testData {
t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
}
}
func TestCorruptKey(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err := NewFieldEncrypt(key)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted, err := ee.Encrypt(testData)
if err != nil {
t.Fatalf("failed to encrypt data: %s", err)
}
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
newKey, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err = NewFieldEncrypt(newKey)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
res, _ := ee.Decrypt(encrypted)
if res == testData {
t.Fatalf("incorrect decryption, the result is: %s", res)
}
}
func TestEncryptDecrypt(t *testing.T) {
// Generate a key for encryption/decryption
key, err := GenerateKey()
if err != nil {
t.Fatalf("Failed to generate key: %v", err)
}
// Initialize the FieldEncrypt with the generated key
ec, err := NewFieldEncrypt(key)
if err != nil {
t.Fatalf("Failed to create FieldEncrypt: %v", err)
}
// Test cases
testCases := []struct {
name string
input string
}{
{
name: "Empty String",
input: "",
},
{
name: "Short String",
input: "Hello",
},
{
name: "String with Spaces",
input: "Hello, World!",
},
{
name: "Long String",
input: "The quick brown fox jumps over the lazy dog.",
},
{
name: "Unicode Characters",
input: "こんにちは世界",
},
{
name: "Special Characters",
input: "!@#$%^&*()_+-=[]{}|;':\",./<>?",
},
{
name: "Numeric String",
input: "1234567890",
},
{
name: "Repeated Characters",
input: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
},
{
name: "Multi-block String",
input: "This is a longer string that will span multiple blocks in the encryption algorithm.",
},
{
name: "Non-ASCII and ASCII Mix",
input: "Hello 世界 123",
},
}
for _, tc := range testCases {
t.Run(tc.name+" - Legacy", func(t *testing.T) {
// Legacy Encryption
encryptedLegacy := ec.LegacyEncrypt(tc.input)
if encryptedLegacy == "" {
t.Errorf("LegacyEncrypt returned empty string for input '%s'", tc.input)
}
// Legacy Decryption
decryptedLegacy, err := ec.LegacyDecrypt(encryptedLegacy)
if err != nil {
t.Errorf("LegacyDecrypt failed for input '%s': %v", tc.input, err)
}
// Verify that the decrypted value matches the original input
if decryptedLegacy != tc.input {
t.Errorf("LegacyDecrypt output '%s' does not match original input '%s'", decryptedLegacy, tc.input)
}
})
t.Run(tc.name+" - New", func(t *testing.T) {
// New Encryption
encryptedNew, err := ec.Encrypt(tc.input)
if err != nil {
t.Errorf("Encrypt failed for input '%s': %v", tc.input, err)
}
if encryptedNew == "" {
t.Errorf("Encrypt returned empty string for input '%s'", tc.input)
}
// New Decryption
decryptedNew, err := ec.Decrypt(encryptedNew)
if err != nil {
t.Errorf("Decrypt failed for input '%s': %v", tc.input, err)
}
// Verify that the decrypted value matches the original input
if decryptedNew != tc.input {
t.Errorf("Decrypt output '%s' does not match original input '%s'", decryptedNew, tc.input)
}
})
}
}
func TestPKCS5UnPadding(t *testing.T) {
tests := []struct {
name string
input []byte
expected []byte
expectError bool
}{
{
name: "Valid Padding",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{4}, 4)...),
expected: []byte("Hello, World!"),
},
{
name: "Empty Input",
input: []byte{},
expectError: true,
},
{
name: "Padding Length Zero",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{0}, 4)...),
expectError: true,
},
{
name: "Padding Length Exceeds Block Size",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{17}, 17)...),
expectError: true,
},
{
name: "Padding Length Exceeds Input Length",
input: []byte{5, 5, 5},
expectError: true,
},
{
name: "Invalid Padding Bytes",
input: append([]byte("Hello, World!"), []byte{2, 3, 4, 5}...),
expectError: true,
},
{
name: "Valid Single Byte Padding",
input: append([]byte("Hello, World!"), byte(1)),
expected: []byte("Hello, World!"),
},
{
name: "Invalid Mixed Padding Bytes",
input: append([]byte("Hello, World!"), []byte{3, 3, 2}...),
expectError: true,
},
{
name: "Valid Full Block Padding",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{16}, 16)...),
expected: []byte("Hello, World!"),
},
{
name: "Non-Padding Byte at End",
input: append([]byte("Hello, World!"), []byte{4, 4, 4, 5}...),
expectError: true,
},
{
name: "Valid Padding with Different Text Length",
input: append([]byte("Test"), bytes.Repeat([]byte{12}, 12)...),
expected: []byte("Test"),
},
{
name: "Padding Length Equal to Input Length",
input: bytes.Repeat([]byte{8}, 8),
expected: []byte{},
},
{
name: "Invalid Padding Length Zero (Again)",
input: append([]byte("Test"), byte(0)),
expectError: true,
},
{
name: "Padding Length Greater Than Input",
input: []byte{10},
expectError: true,
},
{
name: "Input Length Not Multiple of Block Size",
input: append([]byte("Invalid Length"), byte(1)),
expected: []byte("Invalid Length"),
},
{
name: "Valid Padding with Non-ASCII Characters",
input: append([]byte("こんにちは"), bytes.Repeat([]byte{2}, 2)...),
expected: []byte("こんにちは"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := pkcs5UnPadding(tt.input)
if tt.expectError {
if err == nil {
t.Errorf("Expected error but got nil")
}
} else {
if err != nil {
t.Errorf("Did not expect error but got: %v", err)
}
if !bytes.Equal(result, tt.expected) {
t.Errorf("Expected output %v, got %v", tt.expected, result)
}
}
})
}
}

View File

@@ -10,9 +10,10 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/migration"
"github.com/netbirdio/netbird/util/crypt"
)
func migrate(ctx context.Context, crypt *FieldEncrypt, db *gorm.DB) error {
func migrate(ctx context.Context, crypt *crypt.FieldEncrypt, db *gorm.DB) error {
migrations := getMigrations(ctx, crypt)
for _, m := range migrations {
@@ -26,7 +27,7 @@ func migrate(ctx context.Context, crypt *FieldEncrypt, db *gorm.DB) error {
type migrationFunc func(*gorm.DB) error
func getMigrations(ctx context.Context, crypt *FieldEncrypt) []migrationFunc {
func getMigrations(ctx context.Context, crypt *crypt.FieldEncrypt) []migrationFunc {
return []migrationFunc{
func(db *gorm.DB) error {
return migration.MigrateNewField[activity.DeletedUser](ctx, db, "name", "")
@@ -45,7 +46,7 @@ func getMigrations(ctx context.Context, crypt *FieldEncrypt) []migrationFunc {
// migrateLegacyEncryptedUsersToGCM migrates previously encrypted data using
// legacy CBC encryption with a static IV to the new GCM encryption method.
func migrateLegacyEncryptedUsersToGCM(ctx context.Context, db *gorm.DB, crypt *FieldEncrypt) error {
func migrateLegacyEncryptedUsersToGCM(ctx context.Context, db *gorm.DB, crypt *crypt.FieldEncrypt) error {
model := &activity.DeletedUser{}
if !db.Migrator().HasTable(model) {
@@ -80,7 +81,7 @@ func migrateLegacyEncryptedUsersToGCM(ctx context.Context, db *gorm.DB, crypt *F
return nil
}
func updateDeletedUserData(transaction *gorm.DB, user activity.DeletedUser, crypt *FieldEncrypt) error {
func updateDeletedUserData(transaction *gorm.DB, user activity.DeletedUser, crypt *crypt.FieldEncrypt) error {
var err error
var decryptedEmail, decryptedName string

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/migration"
"github.com/netbirdio/netbird/management/server/testutil"
"github.com/netbirdio/netbird/util/crypt"
)
const (
@@ -40,10 +41,10 @@ func setupDatabase(t *testing.T) *gorm.DB {
func TestMigrateLegacyEncryptedUsersToGCM(t *testing.T) {
db := setupDatabase(t)
key, err := GenerateKey()
key, err := crypt.GenerateKey()
require.NoError(t, err, "Failed to generate key")
crypt, err := NewFieldEncrypt(key)
crypt, err := crypt.NewFieldEncrypt(key)
require.NoError(t, err, "Failed to initialize FieldEncrypt")
t.Run("empty table, no migration required", func(t *testing.T) {

View File

@@ -18,6 +18,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util/crypt"
)
const (
@@ -45,12 +46,12 @@ type eventWithNames struct {
// Store is the implementation of the activity.Store interface backed by SQLite
type Store struct {
db *gorm.DB
fieldEncrypt *FieldEncrypt
fieldEncrypt *crypt.FieldEncrypt
}
// NewSqlStore creates a new Store with an event table if not exists.
func NewSqlStore(ctx context.Context, dataDir string, encryptionKey string) (*Store, error) {
crypt, err := NewFieldEncrypt(encryptionKey)
fieldEncrypt, err := crypt.NewFieldEncrypt(encryptionKey)
if err != nil {
return nil, err
@@ -61,7 +62,7 @@ func NewSqlStore(ctx context.Context, dataDir string, encryptionKey string) (*St
return nil, fmt.Errorf("initialize database: %w", err)
}
if err = migrate(ctx, crypt, db); err != nil {
if err = migrate(ctx, fieldEncrypt, db); err != nil {
return nil, fmt.Errorf("events database migration: %w", err)
}
@@ -72,7 +73,7 @@ func NewSqlStore(ctx context.Context, dataDir string, encryptionKey string) (*St
return &Store{
db: db,
fieldEncrypt: crypt,
fieldEncrypt: fieldEncrypt,
}, nil
}

View File

@@ -9,11 +9,12 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/util/crypt"
)
func TestNewSqlStore(t *testing.T) {
dataDir := t.TempDir()
key, _ := GenerateKey()
key, _ := crypt.GenerateKey()
store, err := NewSqlStore(context.Background(), dataDir, key)
if err != nil {
t.Fatal(err)

View File

@@ -16,6 +16,7 @@ import (
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -221,13 +222,14 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
// return empty extra settings for expected calls to UpdateAccountPeers
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
}
func createDNSStore(t *testing.T) (store.Store, error) {

View File

@@ -36,6 +36,9 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMap
Methods("GET", "PUT", "DELETE", "OPTIONS")
router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}/temporary-access", peersHandler.CreateTemporaryAccess).Methods("POST", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs", peersHandler.ListJobs).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs", peersHandler.CreateJob).Methods("POST", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs/{jobId}", peersHandler.GetJob).Methods("GET", "OPTIONS")
}
// NewHandler creates a new peers Handler
@@ -46,6 +49,99 @@ func NewHandler(accountManager account.Manager, networkMapController network_map
}
}
func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(ctx, err, w)
return
}
vars := mux.Vars(r)
peerID := vars["peerId"]
req := &api.JobRequest{}
if err := json.NewDecoder(r.Body).Decode(req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
job, err := types.NewJob(userAuth.UserId, userAuth.AccountId, peerID, req)
if err != nil {
util.WriteError(ctx, err, w)
return
}
if err := h.accountManager.CreatePeerJob(ctx, userAuth.AccountId, peerID, userAuth.UserId, job); err != nil {
util.WriteError(ctx, err, w)
return
}
resp, err := toSingleJobResponse(job)
if err != nil {
util.WriteError(ctx, err, w)
return
}
util.WriteJSONObject(ctx, w, resp)
}
func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(ctx, err, w)
return
}
vars := mux.Vars(r)
peerID := vars["peerId"]
jobs, err := h.accountManager.GetAllPeerJobs(ctx, userAuth.AccountId, userAuth.UserId, peerID)
if err != nil {
util.WriteError(ctx, err, w)
return
}
respBody := make([]*api.JobResponse, 0, len(jobs))
for _, job := range jobs {
resp, err := toSingleJobResponse(job)
if err != nil {
util.WriteError(ctx, err, w)
return
}
respBody = append(respBody, resp)
}
util.WriteJSONObject(ctx, w, respBody)
}
func (h *Handler) GetJob(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(ctx, err, w)
return
}
vars := mux.Vars(r)
peerID := vars["peerId"]
jobID := vars["jobId"]
job, err := h.accountManager.GetPeerJobByID(ctx, userAuth.AccountId, userAuth.UserId, peerID, jobID)
if err != nil {
util.WriteError(ctx, err, w)
return
}
resp, err := toSingleJobResponse(job)
if err != nil {
util.WriteError(ctx, err, w)
return
}
util.WriteJSONObject(ctx, w, resp)
}
func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID)
if err != nil {
@@ -521,6 +617,28 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
}
}
func toSingleJobResponse(job *types.Job) (*api.JobResponse, error) {
workload, err := job.BuildWorkloadResponse()
if err != nil {
return nil, err
}
var failed *string
if job.FailedReason != "" {
failed = &job.FailedReason
}
return &api.JobResponse{
Id: job.ID,
CreatedAt: job.CreatedAt,
CompletedAt: job.CompletedAt,
TriggeredBy: job.TriggeredBy,
Status: api.JobResponseStatus(job.Status),
FailedReason: failed,
Workload: *workload,
}, nil
}
func fqdn(peer *nbpeer.Peer, dnsDomain string) string {
fqdn := peer.FQDN(dnsDomain)
if fqdn == "" {

View File

@@ -33,6 +33,7 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) {
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/approve", userHandler.approveUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/reject", userHandler.rejectUser).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/{userId}/password", userHandler.changePassword).Methods("PUT", "OPTIONS")
addUsersTokensEndpoint(accountManager, router)
}
@@ -410,3 +411,46 @@ func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
// passwordChangeRequest represents the request body for password change
type passwordChangeRequest struct {
OldPassword string `json:"old_password"`
NewPassword string `json:"new_password"`
}
// changePassword is a PUT request to change user's password.
// Only available when embedded IDP is enabled.
// Users can only change their own password.
func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteErrorResponse("invalid user ID", http.StatusBadRequest, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req passwordChangeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
err = h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}

View File

@@ -856,3 +856,118 @@ func TestRejectUserEndpoint(t *testing.T) {
})
}
}
func TestChangePasswordEndpoint(t *testing.T) {
tt := []struct {
name string
expectedStatus int
requestBody string
targetUserID string
currentUserID string
mockError error
expectMockNotCalled bool
}{
{
name: "successful password change",
expectedStatus: http.StatusOK,
requestBody: `{"old_password": "OldPass123!", "new_password": "NewPass456!"}`,
targetUserID: existingUserID,
currentUserID: existingUserID,
mockError: nil,
},
{
name: "missing old password",
expectedStatus: http.StatusUnprocessableEntity,
requestBody: `{"new_password": "NewPass456!"}`,
targetUserID: existingUserID,
currentUserID: existingUserID,
mockError: status.Errorf(status.InvalidArgument, "old password is required"),
},
{
name: "missing new password",
expectedStatus: http.StatusUnprocessableEntity,
requestBody: `{"old_password": "OldPass123!"}`,
targetUserID: existingUserID,
currentUserID: existingUserID,
mockError: status.Errorf(status.InvalidArgument, "new password is required"),
},
{
name: "wrong old password",
expectedStatus: http.StatusUnprocessableEntity,
requestBody: `{"old_password": "WrongPass!", "new_password": "NewPass456!"}`,
targetUserID: existingUserID,
currentUserID: existingUserID,
mockError: status.Errorf(status.InvalidArgument, "invalid password"),
},
{
name: "embedded IDP not enabled",
expectedStatus: http.StatusPreconditionFailed,
requestBody: `{"old_password": "OldPass123!", "new_password": "NewPass456!"}`,
targetUserID: existingUserID,
currentUserID: existingUserID,
mockError: status.Errorf(status.PreconditionFailed, "password change is only available with embedded identity provider"),
},
{
name: "invalid JSON request",
expectedStatus: http.StatusBadRequest,
requestBody: `{invalid json}`,
targetUserID: existingUserID,
currentUserID: existingUserID,
expectMockNotCalled: true,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
mockCalled := false
am := &mock_server.MockAccountManager{}
am.UpdateUserPasswordFunc = func(ctx context.Context, accountID, currentUserID, targetUserID string, oldPassword, newPassword string) error {
mockCalled = true
return tc.mockError
}
handler := newHandler(am)
router := mux.NewRouter()
router.HandleFunc("/users/{userId}/password", handler.changePassword).Methods("PUT")
reqPath := "/users/" + tc.targetUserID + "/password"
req, err := http.NewRequest("PUT", reqPath, bytes.NewBufferString(tc.requestBody))
require.NoError(t, err)
userAuth := auth.UserAuth{
AccountId: existingAccountID,
UserId: tc.currentUserID,
}
ctx := nbcontext.SetUserAuthInContext(req.Context(), userAuth)
req = req.WithContext(ctx)
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectMockNotCalled {
assert.False(t, mockCalled, "mock should not have been called")
}
})
}
}
func TestChangePasswordEndpoint_WrongMethod(t *testing.T) {
am := &mock_server.MockAccountManager{}
handler := newHandler(am)
req, err := http.NewRequest("POST", "/users/test-user/password", bytes.NewBufferString(`{}`))
require.NoError(t, err)
userAuth := auth.UserAuth{
AccountId: existingAccountID,
UserId: existingUserID,
}
req = nbcontext.SetUserAuthInRequest(req, userAuth)
rr := httptest.NewRecorder()
handler.changePassword(rr, req)
assert.Equal(t, http.StatusMethodNotAllowed, rr.Code)
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/management-integrations/integrations"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
@@ -20,6 +21,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
@@ -72,11 +74,14 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
userManager := users.NewManager(store)
permissionsManager := permissions.NewManager(store)
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager)
ctx := context.Background()
requestBuffer := server.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
am, err := server.BuildManager(ctx, nil, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManager), &config.Config{})
am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
@@ -94,7 +99,6 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
resourcesManagerMock := resources.NewManagerMock()
routersManagerMock := routers.NewManagerMock()
groupsManagerMock := groups.NewManagerMock()
peersManager := peers.NewManager(store, permissionsManager)
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)

View File

@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -80,11 +81,12 @@ func createManagerWithEmbeddedIdP(t testing.TB) (*DefaultAccountManager, *update
AnyTimes()
permissionsManager := permissions.NewManager(testStore)
peersManager := peers.NewManager(testStore, permissionsManager)
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, testStore)
networkMapController := controller.NewController(ctx, testStore, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(testStore, peers.NewManager(testStore, permissionsManager)), &config.Config{})
manager, err := BuildManager(ctx, &config.Config{}, testStore, networkMapController, idpManager, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
networkMapController := controller.NewController(ctx, testStore, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(testStore, peersManager), &config.Config{})
manager, err := BuildManager(ctx, &config.Config{}, testStore, networkMapController, job.NewJobManager(nil, testStore, peersManager), idpManager, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, nil, err
}

View File

@@ -400,7 +400,6 @@ func (m *EmbeddedIdPManager) CreateUserWithPassword(ctx context.Context, email,
// InviteUserByID resends an invitation to a user.
func (m *EmbeddedIdPManager) InviteUserByID(ctx context.Context, userID string) error {
// TODO: implement
return fmt.Errorf("not implemented")
}
@@ -432,6 +431,33 @@ func (m *EmbeddedIdPManager) DeleteUser(ctx context.Context, userID string) erro
return nil
}
// UpdateUserPassword updates the password for a user in the embedded IdP.
// It verifies that the current user is changing their own password and
// validates the current password before updating to the new password.
func (m *EmbeddedIdPManager) UpdateUserPassword(ctx context.Context, currentUserID, targetUserID string, oldPassword, newPassword string) error {
// Verify the user is changing their own password
if currentUserID != targetUserID {
return fmt.Errorf("users can only change their own password")
}
// Verify the new password is different from the old password
if oldPassword == newPassword {
return fmt.Errorf("new password must be different from current password")
}
err := m.provider.UpdateUserPassword(ctx, targetUserID, oldPassword, newPassword)
if err != nil {
if m.appMetrics != nil {
m.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
log.WithContext(ctx).Debugf("updated password for user %s in embedded IdP", targetUserID)
return nil
}
// CreateConnector creates a new identity provider connector in Dex.
// Returns the created connector config with the redirect URL populated.
func (m *EmbeddedIdPManager) CreateConnector(ctx context.Context, cfg *dex.ConnectorConfig) (*dex.ConnectorConfig, error) {
@@ -449,15 +475,8 @@ func (m *EmbeddedIdPManager) ListConnectors(ctx context.Context) ([]*dex.Connect
}
// UpdateConnector updates an existing identity provider connector.
// Field preservation for partial updates is handled by Provider.UpdateConnector.
func (m *EmbeddedIdPManager) UpdateConnector(ctx context.Context, cfg *dex.ConnectorConfig) error {
// Preserve existing secret if not provided in update
if cfg.ClientSecret == "" {
existing, err := m.provider.GetConnector(ctx, cfg.ID)
if err != nil {
return fmt.Errorf("failed to get existing connector: %w", err)
}
cfg.ClientSecret = existing.ClientSecret
}
return m.provider.UpdateConnector(ctx, cfg)
}

View File

@@ -248,6 +248,71 @@ func TestEmbeddedIdPManager_UserIDFormat_MatchesJWT(t *testing.T) {
t.Logf(" Connector: %s", connectorID)
}
func TestEmbeddedIdPManager_UpdateUserPassword(t *testing.T) {
ctx := context.Background()
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
config := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: filepath.Join(tmpDir, "dex.db"),
},
},
}
manager, err := NewEmbeddedIdPManager(ctx, config, nil)
require.NoError(t, err)
defer func() { _ = manager.Stop(ctx) }()
// Create a user with a known password
email := "password-test@example.com"
name := "Password Test User"
initialPassword := "InitialPass123!"
userData, err := manager.CreateUserWithPassword(ctx, email, initialPassword, name)
require.NoError(t, err)
require.NotNil(t, userData)
userID := userData.ID
t.Run("successful password change", func(t *testing.T) {
newPassword := "NewSecurePass456!"
err := manager.UpdateUserPassword(ctx, userID, userID, initialPassword, newPassword)
require.NoError(t, err)
// Verify the new password works by changing it again
anotherPassword := "AnotherPass789!"
err = manager.UpdateUserPassword(ctx, userID, userID, newPassword, anotherPassword)
require.NoError(t, err)
})
t.Run("wrong old password", func(t *testing.T) {
err := manager.UpdateUserPassword(ctx, userID, userID, "wrongpassword", "NewPass123!")
require.Error(t, err)
assert.Contains(t, err.Error(), "current password is incorrect")
})
t.Run("cannot change other user password", func(t *testing.T) {
otherUserID := "other-user-id"
err := manager.UpdateUserPassword(ctx, userID, otherUserID, "oldpass", "newpass")
require.Error(t, err)
assert.Contains(t, err.Error(), "users can only change their own password")
})
t.Run("same password rejected", func(t *testing.T) {
samePassword := "SamePass123!"
err := manager.UpdateUserPassword(ctx, userID, userID, samePassword, samePassword)
require.Error(t, err)
assert.Contains(t, err.Error(), "new password must be different")
})
}
func TestEmbeddedIdPManager_GetLocalKeysLocation(t *testing.T) {
ctx := context.Background()

View File

@@ -0,0 +1,59 @@
package job
import (
"context"
"errors"
"fmt"
"sync"
"time"
)
// todo consider the channel buffer size when we allow to run multiple jobs
const jobChannelBuffer = 1
var (
ErrJobChannelClosed = errors.New("job channel closed")
)
type Channel struct {
events chan *Event
once sync.Once
}
func NewChannel() *Channel {
jc := &Channel{
events: make(chan *Event, jobChannelBuffer),
}
return jc
}
func (jc *Channel) AddEvent(ctx context.Context, responseWait time.Duration, event *Event) error {
select {
case <-ctx.Done():
return ctx.Err()
// todo: timeout is handled in the wrong place. If the peer does not respond with the job response, the server does not clean it up from the pending jobs and cannot apply a new job
case <-time.After(responseWait):
return fmt.Errorf("failed to add the event to the channel")
case jc.events <- event:
}
return nil
}
func (jc *Channel) Close() {
jc.once.Do(func() {
close(jc.events)
})
}
func (jc *Channel) Event(ctx context.Context) (*Event, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
case job, open := <-jc.events:
if !open {
return nil, ErrJobChannelClosed
}
return job, nil
}
}

View File

@@ -0,0 +1,182 @@
package job
import (
"context"
"fmt"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
type Event struct {
PeerID string
Request *proto.JobRequest
Response *proto.JobResponse
}
type Manager struct {
mu *sync.RWMutex
jobChannels map[string]*Channel // per-peer job streams
pending map[string]*Event // jobID → event
responseWait time.Duration
metrics telemetry.AppMetrics
Store store.Store
peersManager peers.Manager
}
func NewJobManager(metrics telemetry.AppMetrics, store store.Store, peersManager peers.Manager) *Manager {
return &Manager{
jobChannels: make(map[string]*Channel),
pending: make(map[string]*Event),
responseWait: 5 * time.Minute,
metrics: metrics,
mu: &sync.RWMutex{},
Store: store,
peersManager: peersManager,
}
}
// CreateJobChannel creates or replaces a channel for a peer
func (jm *Manager) CreateJobChannel(ctx context.Context, accountID, peerID string) *Channel {
// all pending jobs stored in db for this peer should be failed
if err := jm.Store.MarkAllPendingJobsAsFailed(ctx, accountID, peerID, "Pending job cleanup: marked as failed automatically due to being stuck too long"); err != nil {
log.WithContext(ctx).Error(err.Error())
}
jm.mu.Lock()
defer jm.mu.Unlock()
if ch, ok := jm.jobChannels[peerID]; ok {
ch.Close()
delete(jm.jobChannels, peerID)
}
ch := NewChannel()
jm.jobChannels[peerID] = ch
return ch
}
// SendJob sends a job to a peer and tracks it as pending
func (jm *Manager) SendJob(ctx context.Context, accountID, peerID string, req *proto.JobRequest) error {
jm.mu.RLock()
ch, ok := jm.jobChannels[peerID]
jm.mu.RUnlock()
if !ok {
return fmt.Errorf("peer %s has no channel", peerID)
}
event := &Event{
PeerID: peerID,
Request: req,
}
jm.mu.Lock()
jm.pending[string(req.ID)] = event
jm.mu.Unlock()
if err := ch.AddEvent(ctx, jm.responseWait, event); err != nil {
jm.cleanup(ctx, accountID, string(req.ID), err.Error())
return err
}
return nil
}
// HandleResponse marks a job as finished and moves it to completed
func (jm *Manager) HandleResponse(ctx context.Context, resp *proto.JobResponse, peerKey string) error {
jm.mu.Lock()
defer jm.mu.Unlock()
// todo: validate job ID and would be nice to use uuid text marshal instead of string
jobID := string(resp.ID)
// todo: in this map has jobs for all peers in any account. Consider to validate the jobID association for the peer
event, ok := jm.pending[jobID]
if !ok {
return fmt.Errorf("job %s not found", jobID)
}
var job types.Job
// todo: ApplyResponse should be static. Any member value is unusable in this way
if err := job.ApplyResponse(resp); err != nil {
return fmt.Errorf("invalid job response: %v", err)
}
peerID, err := jm.peersManager.GetPeerID(ctx, peerKey)
if err != nil {
return fmt.Errorf("failed to get peer ID: %v", err)
}
if peerID != event.PeerID {
return fmt.Errorf("peer ID mismatch: %s != %s", peerID, event.PeerID)
}
// update or create the store for job response
err = jm.Store.CompletePeerJob(ctx, &job)
if err != nil {
return fmt.Errorf("failed to complete job %s: %v", jobID, err)
}
delete(jm.pending, jobID)
return nil
}
// CloseChannel closes a peers channel and cleans up its jobs
func (jm *Manager) CloseChannel(ctx context.Context, accountID, peerID string) {
jm.mu.Lock()
defer jm.mu.Unlock()
if ch, ok := jm.jobChannels[peerID]; ok {
ch.Close()
delete(jm.jobChannels, peerID)
}
for jobID, ev := range jm.pending {
if ev.PeerID == peerID {
// if the client disconnect and there is pending job then mark it as failed
if err := jm.Store.MarkPendingJobsAsFailed(ctx, accountID, peerID, jobID, "Time out peer disconnected"); err != nil {
log.WithContext(ctx).Errorf("failed to mark pending jobs as failed: %v", err)
}
delete(jm.pending, jobID)
}
}
}
// cleanup removes a pending job safely
func (jm *Manager) cleanup(ctx context.Context, accountID, jobID string, reason string) {
jm.mu.Lock()
defer jm.mu.Unlock()
if ev, ok := jm.pending[jobID]; ok {
if err := jm.Store.MarkPendingJobsAsFailed(ctx, accountID, ev.PeerID, jobID, reason); err != nil {
log.WithContext(ctx).Errorf("failed to mark pending jobs as failed: %v", err)
}
delete(jm.pending, jobID)
}
}
func (jm *Manager) IsPeerConnected(peerID string) bool {
jm.mu.RLock()
defer jm.mu.RUnlock()
_, ok := jm.jobChannels[peerID]
return ok
}
func (jm *Manager) IsPeerHasPendingJobs(peerID string) bool {
jm.mu.RLock()
defer jm.mu.RUnlock()
for _, ev := range jm.pending {
if ev.PeerID == peerID {
return true
}
}
return false
}

View File

@@ -31,6 +31,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
@@ -361,13 +362,15 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
AnyTimes()
permissionsManager := permissions.NewManager(store)
groupsManager := groups.NewManagerMock()
peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager)
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
ephemeralMgr := manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager))
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeralMgr, config)
accountManager, err := BuildManager(ctx, nil, store, networkMapController, nil, "",
accountManager, err := BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "",
eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
@@ -381,7 +384,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
return nil, nil, "", cleanup, err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
return nil, nil, "", cleanup, err
}

View File

@@ -30,6 +30,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -202,6 +203,8 @@ func startServer(
AnyTimes()
permissionsManager := permissions.NewManager(str)
peersManager := peers.NewManager(str, permissionsManager)
jobManager := job.NewJobManager(nil, str, peersManager)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
@@ -213,6 +216,7 @@ func startServer(
nil,
str,
networkMapController,
jobManager,
nil,
"",
eventStore,
@@ -237,6 +241,7 @@ func startServer(
config,
accountManager,
settingsMockManager,
jobManager,
secretsManager,
nil,
nil,

View File

@@ -74,6 +74,7 @@ type MockAccountManager struct {
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error)
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
UpdateUserPasswordFunc func(ctx context.Context, accountID, currentUserID, targetUserID string, oldPassword, newPassword string) error
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error)
@@ -135,6 +136,29 @@ type MockAccountManager struct {
CreateIdentityProviderFunc func(ctx context.Context, accountID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error)
UpdateIdentityProviderFunc func(ctx context.Context, accountID, idpID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error)
DeleteIdentityProviderFunc func(ctx context.Context, accountID, idpID, userID string) error
CreatePeerJobFunc func(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
GetAllPeerJobsFunc func(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
GetPeerJobByIDFunc func(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
}
func (am *MockAccountManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error {
if am.CreatePeerJobFunc != nil {
return am.CreatePeerJobFunc(ctx, accountID, peerID, userID, job)
}
return status.Errorf(codes.Unimplemented, "method CreatePeerJob is not implemented")
}
func (am *MockAccountManager) GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error) {
if am.GetAllPeerJobsFunc != nil {
return am.GetAllPeerJobsFunc(ctx, accountID, userID, peerID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAllPeerJobs is not implemented")
}
func (am *MockAccountManager) GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error) {
if am.GetPeerJobByIDFunc != nil {
return am.GetPeerJobByIDFunc(ctx, accountID, userID, peerID, jobID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeerJobByID is not implemented")
}
func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
@@ -612,6 +636,14 @@ func (am *MockAccountManager) DeleteRegularUsers(ctx context.Context, accountID,
return status.Errorf(codes.Unimplemented, "method DeleteRegularUsers is not implemented")
}
// UpdateUserPassword mocks UpdateUserPassword of the AccountManager interface
func (am *MockAccountManager) UpdateUserPassword(ctx context.Context, accountID, currentUserID, targetUserID string, oldPassword, newPassword string) error {
if am.UpdateUserPasswordFunc != nil {
return am.UpdateUserPasswordFunc(ctx, accountID, currentUserID, targetUserID, oldPassword, newPassword)
}
return status.Errorf(codes.Unimplemented, "method UpdateUserPassword is not implemented")
}
func (am *MockAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
if am.InviteUserFunc != nil {
return am.InviteUserFunc(ctx, accountID, initiatorUserID, targetUserID)

View File

@@ -18,6 +18,7 @@ import (
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
@@ -790,13 +791,14 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
AnyTimes()
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
}
func createNSStore(t *testing.T) (store.Store, error) {

View File

@@ -31,6 +31,8 @@ import (
"github.com/netbirdio/netbird/shared/management/status"
)
const remoteJobsMinVer = "0.64.0"
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// the current user is not an admin.
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
@@ -324,6 +326,134 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return peer, nil
}
func (am *DefaultAccountManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.RemoteJobs, operations.Create)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return err
}
if p.AccountID != accountID {
return status.NewPeerNotPartOfAccountError()
}
meetMinVer, err := posture.MeetsMinVersion(remoteJobsMinVer, p.Meta.WtVersion)
if !strings.Contains(p.Meta.WtVersion, "dev") && (!meetMinVer || err != nil) {
return status.Errorf(status.PreconditionFailed, "peer version %s does not meet the minimum required version %s for remote jobs", p.Meta.WtVersion, remoteJobsMinVer)
}
if !am.jobManager.IsPeerConnected(peerID) {
return status.Errorf(status.BadRequest, "peer not connected")
}
// check if already has pending jobs
// todo: The job checks here are not protected. The user can run this function from multiple threads,
// and each thread can think there is no job yet. This means entries in the pending job map will be overwritten,
// and only one will be kept, but potentially another one will overwrite it in the queue.
if am.jobManager.IsPeerHasPendingJobs(peerID) {
return status.Errorf(status.BadRequest, "peer already has pending job")
}
jobStream, err := job.ToStreamJobRequest()
if err != nil {
return status.Errorf(status.BadRequest, "invalid job request %v", err)
}
// try sending job first
if err := am.jobManager.SendJob(ctx, accountID, peerID, jobStream); err != nil {
return status.Errorf(status.Internal, "failed to send job: %v", err)
}
var peer *nbpeer.Peer
var eventsToStore func()
// persist job in DB only if send succeeded
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID)
if err != nil {
return err
}
if err := transaction.CreatePeerJob(ctx, job); err != nil {
return err
}
jobMeta := map[string]any{
"for_peer_name": peer.Name,
"job_type": job.Workload.Type,
}
eventsToStore = func() {
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.JobCreatedByUser, jobMeta)
}
return nil
})
if err != nil {
return err
}
eventsToStore()
return nil
}
func (am *DefaultAccountManager) GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error) {
// todo: Create permissions for job
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.RemoteJobs, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
if err != nil {
return nil, err
}
if peerAccountID != accountID {
return nil, status.NewPeerNotPartOfAccountError()
}
accountJobs, err := am.Store.GetPeerJobs(ctx, accountID, peerID)
if err != nil {
return nil, err
}
return accountJobs, nil
}
func (am *DefaultAccountManager) GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.RemoteJobs, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
if err != nil {
return nil, err
}
if peerAccountID != accountID {
return nil, status.NewPeerNotPartOfAccountError()
}
job, err := am.Store.GetPeerJobByID(ctx, accountID, jobID)
if err != nil {
return nil, err
}
return job, nil
}
// DeletePeer removes peer from the account by its IP
func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Delete)
@@ -598,6 +728,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
return fmt.Errorf("failed adding peer to All group: %w", err)
}
if temporary {
// we should track ephemeral peers to be able to clean them if the peer don't sync and be marked as connected
am.networkMapController.TrackEphemeralPeer(ctx, newPeer)
}
if addedByUser {
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
if err != nil {

View File

@@ -34,6 +34,7 @@ import (
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/shared/management/status"
@@ -1289,13 +1290,14 @@ func Test_RegisterPeerByUser(t *testing.T) {
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManager := permissions.NewManager(s)
peersManager := peers.NewManager(s, permissionsManager)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1374,13 +1376,14 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
Return(&types.ExtraSettings{}, nil).
AnyTimes()
permissionsManager := permissions.NewManager(s)
peersManager := peers.NewManager(s, permissionsManager)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1527,13 +1530,14 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManager := permissions.NewManager(s)
peersManager := peers.NewManager(s, permissionsManager)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1607,13 +1611,14 @@ func Test_LoginPeer(t *testing.T) {
Return(&types.ExtraSettings{}, nil).
AnyTimes()
permissionsManager := permissions.NewManager(s)
peersManager := peers.NewManager(s, permissionsManager)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"

View File

@@ -3,35 +3,37 @@ package modules
type Module string
const (
Networks Module = "networks"
Peers Module = "peers"
Groups Module = "groups"
Settings Module = "settings"
Accounts Module = "accounts"
Dns Module = "dns"
Nameservers Module = "nameservers"
Events Module = "events"
Policies Module = "policies"
Routes Module = "routes"
Users Module = "users"
SetupKeys Module = "setup_keys"
Pats Module = "pats"
Networks Module = "networks"
Peers Module = "peers"
RemoteJobs Module = "remote_jobs"
Groups Module = "groups"
Settings Module = "settings"
Accounts Module = "accounts"
Dns Module = "dns"
Nameservers Module = "nameservers"
Events Module = "events"
Policies Module = "policies"
Routes Module = "routes"
Users Module = "users"
SetupKeys Module = "setup_keys"
Pats Module = "pats"
IdentityProviders Module = "identity_providers"
)
var All = map[Module]struct{}{
Networks: {},
Peers: {},
Groups: {},
Settings: {},
Accounts: {},
Dns: {},
Nameservers: {},
Events: {},
Policies: {},
Routes: {},
Users: {},
SetupKeys: {},
Pats: {},
Networks: {},
Peers: {},
RemoteJobs: {},
Groups: {},
Settings: {},
Accounts: {},
Dns: {},
Nameservers: {},
Events: {},
Policies: {},
Routes: {},
Users: {},
SetupKeys: {},
Pats: {},
IdentityProviders: {},
}

View File

@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -1289,13 +1290,14 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel.
Return(&types.ExtraSettings{}, nil)
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
am, err := BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
am, err := BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, nil, err
}

View File

@@ -43,14 +43,15 @@ import (
)
const (
storeSqliteFileName = "store.db"
idQueryCondition = "id = ?"
keyQueryCondition = "key = ?"
mysqlKeyQueryCondition = "`key` = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?"
accountAndIDsQueryCondition = "account_id = ? AND id IN ?"
accountIDCondition = "account_id = ?"
peerNotFoundFMT = "peer %s not found"
storeSqliteFileName = "store.db"
idQueryCondition = "id = ?"
keyQueryCondition = "key = ?"
mysqlKeyQueryCondition = "`key` = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?"
accountAndPeerIDQueryCondition = "account_id = ? and peer_id = ?"
accountAndIDsQueryCondition = "account_id = ? AND id IN ?"
accountIDCondition = "account_id = ?"
peerNotFoundFMT = "peer %s not found"
pgMaxConnections = 30
pgMinConnections = 1
@@ -125,7 +126,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
&zones.Zone{}, &records.Record{},
&types.Job{}, &zones.Zone{}, &records.Record{},
)
if err != nil {
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
@@ -144,6 +145,97 @@ func GetKeyQueryCondition(s *SqlStore) string {
return keyQueryCondition
}
// SaveJob persists a job in DB
func (s *SqlStore) CreatePeerJob(ctx context.Context, job *types.Job) error {
result := s.db.Create(job)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to create job in store: %s", result.Error)
return status.Errorf(status.Internal, "failed to create job in store")
}
return nil
}
func (s *SqlStore) CompletePeerJob(ctx context.Context, job *types.Job) error {
result := s.db.
Model(&types.Job{}).
Where(idQueryCondition, job.ID).
Updates(job)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to update job in store: %s", result.Error)
return status.Errorf(status.Internal, "failed to update job in store")
}
return nil
}
// job was pending for too long and has been cancelled
func (s *SqlStore) MarkPendingJobsAsFailed(ctx context.Context, accountID, peerID, jobID, reason string) error {
now := time.Now().UTC()
result := s.db.
Model(&types.Job{}).
Where(accountAndPeerIDQueryCondition+" AND id = ?"+" AND status = ?", accountID, peerID, jobID, types.JobStatusPending).
Updates(types.Job{
Status: types.JobStatusFailed,
FailedReason: reason,
CompletedAt: &now,
})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to mark pending jobs as Failed job in store: %s", result.Error)
return status.Errorf(status.Internal, "failed to mark pending job as Failed in store")
}
return nil
}
// job was pending for too long and has been cancelled
func (s *SqlStore) MarkAllPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error {
now := time.Now().UTC()
result := s.db.
Model(&types.Job{}).
Where(accountAndPeerIDQueryCondition+" AND status = ?", accountID, peerID, types.JobStatusPending).
Updates(types.Job{
Status: types.JobStatusFailed,
FailedReason: reason,
CompletedAt: &now,
})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to mark pending jobs as Failed job in store: %s", result.Error)
return status.Errorf(status.Internal, "failed to mark pending job as Failed in store")
}
return nil
}
// GetJobByID fetches job by ID
func (s *SqlStore) GetPeerJobByID(ctx context.Context, accountID, jobID string) (*types.Job, error) {
var job types.Job
err := s.db.
Where(accountAndIDQueryCondition, accountID, jobID).
First(&job).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "job %s not found", jobID)
}
if err != nil {
log.WithContext(ctx).Errorf("failed to fetch job from store: %s", err)
return nil, err
}
return &job, nil
}
// get all jobs
func (s *SqlStore) GetPeerJobs(ctx context.Context, accountID, peerID string) ([]*types.Job, error) {
var jobs []*types.Job
err := s.db.
Where(accountAndPeerIDQueryCondition, accountID, peerID).
Order("created_at DESC").
Find(&jobs).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to fetch jobs from store: %s", err)
return nil, err
}
return jobs, nil
}
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring global lock")
@@ -4363,3 +4455,23 @@ func (s *SqlStore) DeleteZoneDNSRecords(ctx context.Context, accountID, zoneID s
return nil
}
func (s *SqlStore) GetPeerIDByKey(ctx context.Context, lockStrength LockingStrength, key string) (string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var peerID string
result := tx.Model(&nbpeer.Peer{}).
Select("id").
Where(GetKeyQueryCondition(s), key).
Limit(1).
Scan(&peerID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get peer ID by key: %s", result.Error)
return "", status.Errorf(status.Internal, "failed to get peer ID by key")
}
return peerID, nil
}

View File

@@ -226,6 +226,13 @@ type Store interface {
GetZoneDNSRecords(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) ([]*records.Record, error)
GetZoneDNSRecordsByName(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, name string) ([]*records.Record, error)
DeleteZoneDNSRecords(ctx context.Context, accountID, zoneID string) error
CreatePeerJob(ctx context.Context, job *types.Job) error
CompletePeerJob(ctx context.Context, job *types.Job) error
GetPeerJobByID(ctx context.Context, accountID, jobID string) (*types.Job, error)
GetPeerJobs(ctx context.Context, accountID, peerID string) ([]*types.Job, error)
MarkPendingJobsAsFailed(ctx context.Context, accountID, peerID, jobID, reason string) error
MarkAllPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error
GetPeerIDByKey(ctx context.Context, lockStrength LockingStrength, key string) (string, error)
}
const (

Some files were not shown because too many files have changed in this diff Show More