mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-30 06:06:38 +00:00
Compare commits
21 Commits
feature/di
...
v0.64.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
67211010f7 | ||
|
|
c61568ceb4 | ||
|
|
737d6061bf | ||
|
|
ee3a67d2d8 | ||
|
|
1a32e4c223 | ||
|
|
269d5d1cba | ||
|
|
a1de2b8a98 | ||
|
|
d0221a3e72 | ||
|
|
8da23daae3 | ||
|
|
f86022eace | ||
|
|
ee54827f94 | ||
|
|
e908dea702 | ||
|
|
030650a905 | ||
|
|
e01998815e | ||
|
|
07e4a5a23c | ||
|
|
b3a2992a10 | ||
|
|
202fa47f2b | ||
|
|
4888021ba6 | ||
|
|
a0b0b664b6 | ||
|
|
50da5074e7 | ||
|
|
58daa674ef |
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
|
||||
skip: go.mod,go.sum
|
||||
golangci:
|
||||
strategy:
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
@@ -98,7 +97,6 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
request := &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: getStatusOutput(cmd, anonymizeFlag),
|
||||
SystemInfo: systemInfoFlag,
|
||||
LogFileCount: logFileCount,
|
||||
}
|
||||
@@ -221,21 +219,37 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
|
||||
cpuProfilingStarted := false
|
||||
if _, err := client.StartCPUProfile(cmd.Context(), &proto.StartCPUProfileRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to start CPU profiling: %v\n", err)
|
||||
} else {
|
||||
cpuProfilingStarted = true
|
||||
defer func() {
|
||||
if cpuProfilingStarted {
|
||||
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
||||
return waitErr
|
||||
}
|
||||
cmd.Println("\nDuration completed")
|
||||
|
||||
if cpuProfilingStarted {
|
||||
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
|
||||
} else {
|
||||
cpuProfilingStarted = false
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Println("Creating debug bundle...")
|
||||
|
||||
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
||||
request := &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: statusOutput,
|
||||
SystemInfo: systemInfoFlag,
|
||||
LogFileCount: logFileCount,
|
||||
}
|
||||
@@ -302,24 +316,6 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
||||
var statusOutputString string
|
||||
statusResp, err := getStatus(cmd.Context(), true)
|
||||
if err != nil {
|
||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||
} else {
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName)
|
||||
statusOutputString = overview.FullDetailSummary()
|
||||
}
|
||||
return statusOutputString
|
||||
}
|
||||
|
||||
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
@@ -378,7 +374,8 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
|
||||
InternalConfig: config,
|
||||
StatusRecorder: recorder,
|
||||
SyncResponse: syncResponse,
|
||||
LogFile: logFilePath,
|
||||
LogPath: logFilePath,
|
||||
CPUProfile: nil,
|
||||
},
|
||||
debug.BundleConfig{
|
||||
IncludeSystemInfo: true,
|
||||
|
||||
@@ -99,7 +99,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
|
||||
var statusOutputString string
|
||||
switch {
|
||||
case detailFlag:
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||
client "github.com/netbirdio/netbird/client/server"
|
||||
@@ -97,6 +98,8 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
peersmanager := peers.NewManager(store, permissionsManagerMock)
|
||||
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||
|
||||
jobManager := job.NewJobManager(nil, store, peersmanager)
|
||||
|
||||
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
@@ -115,7 +118,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -124,7 +127,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -200,7 +200,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
||||
connectClient := internal.NewConnectClient(ctx, config, r, false)
|
||||
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||
|
||||
return connectClient.Run(nil)
|
||||
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
|
||||
}
|
||||
|
||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {
|
||||
|
||||
@@ -190,7 +190,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
run := make(chan struct{})
|
||||
clientErr := make(chan error, 1)
|
||||
go func() {
|
||||
if err := client.Run(run); err != nil {
|
||||
if err := client.Run(run, ""); err != nil {
|
||||
clientErr <- err
|
||||
}
|
||||
}()
|
||||
|
||||
169
client/iface/bind/dual_stack_conn.go
Normal file
169
client/iface/bind/dual_stack_conn.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
errNoIPv4Conn = errors.New("no IPv4 connection available")
|
||||
errNoIPv6Conn = errors.New("no IPv6 connection available")
|
||||
errInvalidAddr = errors.New("invalid address type")
|
||||
)
|
||||
|
||||
// DualStackPacketConn wraps IPv4 and IPv6 UDP connections and routes writes
|
||||
// to the appropriate connection based on the destination address.
|
||||
// ReadFrom is not used in the hot path - ICEBind receives packets via
|
||||
// BatchReader.ReadBatch() directly. This is only used by udpMux for sending.
|
||||
type DualStackPacketConn struct {
|
||||
ipv4Conn net.PacketConn
|
||||
ipv6Conn net.PacketConn
|
||||
|
||||
readFromWarn sync.Once
|
||||
}
|
||||
|
||||
// NewDualStackPacketConn creates a new dual-stack packet connection.
|
||||
func NewDualStackPacketConn(ipv4Conn, ipv6Conn net.PacketConn) *DualStackPacketConn {
|
||||
return &DualStackPacketConn{
|
||||
ipv4Conn: ipv4Conn,
|
||||
ipv6Conn: ipv6Conn,
|
||||
}
|
||||
}
|
||||
|
||||
// ReadFrom reads from the available connection (preferring IPv4).
|
||||
// NOTE: This method is NOT used in the data path. ICEBind receives packets via
|
||||
// BatchReader.ReadBatch() directly for both IPv4 and IPv6, which is much more efficient.
|
||||
// This implementation exists only to satisfy the net.PacketConn interface for the udpMux,
|
||||
// but the udpMux only uses WriteTo() for sending STUN responses - it never calls ReadFrom()
|
||||
// because STUN packets are filtered and forwarded via HandleSTUNMessage() from the receive path.
|
||||
func (d *DualStackPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||
d.readFromWarn.Do(func() {
|
||||
log.Warn("DualStackPacketConn.ReadFrom called - this is unexpected and may indicate an inefficient code path")
|
||||
})
|
||||
|
||||
if d.ipv4Conn != nil {
|
||||
return d.ipv4Conn.ReadFrom(b)
|
||||
}
|
||||
if d.ipv6Conn != nil {
|
||||
return d.ipv6Conn.ReadFrom(b)
|
||||
}
|
||||
return 0, nil, net.ErrClosed
|
||||
}
|
||||
|
||||
// WriteTo writes to the appropriate connection based on the address type.
|
||||
func (d *DualStackPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
udpAddr, ok := addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return 0, &net.OpError{
|
||||
Op: "write",
|
||||
Net: "udp",
|
||||
Addr: addr,
|
||||
Err: errInvalidAddr,
|
||||
}
|
||||
}
|
||||
|
||||
if udpAddr.IP.To4() == nil {
|
||||
if d.ipv6Conn != nil {
|
||||
return d.ipv6Conn.WriteTo(b, addr)
|
||||
}
|
||||
return 0, &net.OpError{
|
||||
Op: "write",
|
||||
Net: "udp6",
|
||||
Addr: addr,
|
||||
Err: errNoIPv6Conn,
|
||||
}
|
||||
}
|
||||
|
||||
if d.ipv4Conn != nil {
|
||||
return d.ipv4Conn.WriteTo(b, addr)
|
||||
}
|
||||
return 0, &net.OpError{
|
||||
Op: "write",
|
||||
Net: "udp4",
|
||||
Addr: addr,
|
||||
Err: errNoIPv4Conn,
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes both connections.
|
||||
func (d *DualStackPacketConn) Close() error {
|
||||
var result *multierror.Error
|
||||
if d.ipv4Conn != nil {
|
||||
if err := d.ipv4Conn.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
if d.ipv6Conn != nil {
|
||||
if err := d.ipv6Conn.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// LocalAddr returns the local address of the IPv4 connection if available,
|
||||
// otherwise the IPv6 connection.
|
||||
func (d *DualStackPacketConn) LocalAddr() net.Addr {
|
||||
if d.ipv4Conn != nil {
|
||||
return d.ipv4Conn.LocalAddr()
|
||||
}
|
||||
if d.ipv6Conn != nil {
|
||||
return d.ipv6Conn.LocalAddr()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDeadline sets the deadline for both connections.
|
||||
func (d *DualStackPacketConn) SetDeadline(t time.Time) error {
|
||||
var result *multierror.Error
|
||||
if d.ipv4Conn != nil {
|
||||
if err := d.ipv4Conn.SetDeadline(t); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
if d.ipv6Conn != nil {
|
||||
if err := d.ipv6Conn.SetDeadline(t); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the read deadline for both connections.
|
||||
func (d *DualStackPacketConn) SetReadDeadline(t time.Time) error {
|
||||
var result *multierror.Error
|
||||
if d.ipv4Conn != nil {
|
||||
if err := d.ipv4Conn.SetReadDeadline(t); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
if d.ipv6Conn != nil {
|
||||
if err := d.ipv6Conn.SetReadDeadline(t); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// SetWriteDeadline sets the write deadline for both connections.
|
||||
func (d *DualStackPacketConn) SetWriteDeadline(t time.Time) error {
|
||||
var result *multierror.Error
|
||||
if d.ipv4Conn != nil {
|
||||
if err := d.ipv4Conn.SetWriteDeadline(t); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
if d.ipv6Conn != nil {
|
||||
if err := d.ipv6Conn.SetWriteDeadline(t); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
119
client/iface/bind/dual_stack_conn_bench_test.go
Normal file
119
client/iface/bind/dual_stack_conn_bench_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
ipv4Addr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345}
|
||||
ipv6Addr = &net.UDPAddr{IP: net.ParseIP("::1"), Port: 12345}
|
||||
payload = make([]byte, 1200)
|
||||
)
|
||||
|
||||
func BenchmarkWriteTo_DirectUDPConn(b *testing.B) {
|
||||
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = conn.WriteTo(payload, ipv4Addr)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWriteTo_DualStack_IPv4Only(b *testing.B) {
|
||||
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ds := NewDualStackPacketConn(conn, nil)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = ds.WriteTo(payload, ipv4Addr)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWriteTo_DualStack_IPv6Only(b *testing.B) {
|
||||
conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||
if err != nil {
|
||||
b.Skipf("IPv6 not available: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ds := NewDualStackPacketConn(nil, conn)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = ds.WriteTo(payload, ipv6Addr)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWriteTo_DualStack_Both_IPv4Traffic(b *testing.B) {
|
||||
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer conn4.Close()
|
||||
|
||||
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||
if err != nil {
|
||||
b.Skipf("IPv6 not available: %v", err)
|
||||
}
|
||||
defer conn6.Close()
|
||||
|
||||
ds := NewDualStackPacketConn(conn4, conn6)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = ds.WriteTo(payload, ipv4Addr)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWriteTo_DualStack_Both_IPv6Traffic(b *testing.B) {
|
||||
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer conn4.Close()
|
||||
|
||||
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||
if err != nil {
|
||||
b.Skipf("IPv6 not available: %v", err)
|
||||
}
|
||||
defer conn6.Close()
|
||||
|
||||
ds := NewDualStackPacketConn(conn4, conn6)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = ds.WriteTo(payload, ipv6Addr)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWriteTo_DualStack_Both_MixedTraffic(b *testing.B) {
|
||||
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer conn4.Close()
|
||||
|
||||
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||
if err != nil {
|
||||
b.Skipf("IPv6 not available: %v", err)
|
||||
}
|
||||
defer conn6.Close()
|
||||
|
||||
ds := NewDualStackPacketConn(conn4, conn6)
|
||||
addrs := []net.Addr{ipv4Addr, ipv6Addr}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = ds.WriteTo(payload, addrs[i&1])
|
||||
}
|
||||
}
|
||||
191
client/iface/bind/dual_stack_conn_test.go
Normal file
191
client/iface/bind/dual_stack_conn_test.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDualStackPacketConn_RoutesWritesToCorrectSocket(t *testing.T) {
|
||||
ipv4Conn := &mockPacketConn{network: "udp4"}
|
||||
ipv6Conn := &mockPacketConn{network: "udp6"}
|
||||
dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
addr *net.UDPAddr
|
||||
wantSocket string
|
||||
}{
|
||||
{
|
||||
name: "IPv4 address",
|
||||
addr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234},
|
||||
wantSocket: "udp4",
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
addr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234},
|
||||
wantSocket: "udp6",
|
||||
},
|
||||
{
|
||||
name: "IPv4-mapped IPv6 goes to IPv4",
|
||||
addr: &net.UDPAddr{IP: net.ParseIP("::ffff:192.168.1.1"), Port: 1234},
|
||||
wantSocket: "udp4",
|
||||
},
|
||||
{
|
||||
name: "IPv4 loopback",
|
||||
addr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234},
|
||||
wantSocket: "udp4",
|
||||
},
|
||||
{
|
||||
name: "IPv6 loopback",
|
||||
addr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 1234},
|
||||
wantSocket: "udp6",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ipv4Conn.writeCount = 0
|
||||
ipv6Conn.writeCount = 0
|
||||
|
||||
n, err := dualStack.WriteTo([]byte("test"), tt.addr)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 4, n)
|
||||
|
||||
if tt.wantSocket == "udp4" {
|
||||
assert.Equal(t, 1, ipv4Conn.writeCount, "expected write to IPv4")
|
||||
assert.Equal(t, 0, ipv6Conn.writeCount, "expected no write to IPv6")
|
||||
} else {
|
||||
assert.Equal(t, 0, ipv4Conn.writeCount, "expected no write to IPv4")
|
||||
assert.Equal(t, 1, ipv6Conn.writeCount, "expected write to IPv6")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDualStackPacketConn_IPv4OnlyRejectsIPv6(t *testing.T) {
|
||||
dualStack := NewDualStackPacketConn(&mockPacketConn{network: "udp4"}, nil)
|
||||
|
||||
// IPv4 works
|
||||
_, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234})
|
||||
require.NoError(t, err)
|
||||
|
||||
// IPv6 fails
|
||||
_, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no IPv6 connection")
|
||||
}
|
||||
|
||||
func TestDualStackPacketConn_IPv6OnlyRejectsIPv4(t *testing.T) {
|
||||
dualStack := NewDualStackPacketConn(nil, &mockPacketConn{network: "udp6"})
|
||||
|
||||
// IPv6 works
|
||||
_, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234})
|
||||
require.NoError(t, err)
|
||||
|
||||
// IPv4 fails
|
||||
_, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no IPv4 connection")
|
||||
}
|
||||
|
||||
// TestDualStackPacketConn_ReadFromIsNotUsedInHotPath documents that ReadFrom
|
||||
// only reads from one socket (IPv4 preferred). This is fine because the actual
|
||||
// receive path uses wireguard-go's BatchReader directly, not ReadFrom.
|
||||
func TestDualStackPacketConn_ReadFromIsNotUsedInHotPath(t *testing.T) {
|
||||
ipv4Conn := &mockPacketConn{
|
||||
network: "udp4",
|
||||
readData: []byte("from ipv4"),
|
||||
readAddr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234},
|
||||
}
|
||||
ipv6Conn := &mockPacketConn{
|
||||
network: "udp6",
|
||||
readData: []byte("from ipv6"),
|
||||
readAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234},
|
||||
}
|
||||
|
||||
dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn)
|
||||
|
||||
buf := make([]byte, 100)
|
||||
n, addr, err := dualStack.ReadFrom(buf)
|
||||
|
||||
require.NoError(t, err)
|
||||
// reads from IPv4 (preferred) - this is expected behavior
|
||||
assert.Equal(t, "from ipv4", string(buf[:n]))
|
||||
assert.Equal(t, "192.168.1.1", addr.(*net.UDPAddr).IP.String())
|
||||
}
|
||||
|
||||
func TestDualStackPacketConn_LocalAddrPrefersIPv4(t *testing.T) {
|
||||
ipv4Addr := &net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: 51820}
|
||||
ipv6Addr := &net.UDPAddr{IP: net.ParseIP("::"), Port: 51820}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ipv4 net.PacketConn
|
||||
ipv6 net.PacketConn
|
||||
wantAddr net.Addr
|
||||
}{
|
||||
{
|
||||
name: "both available returns IPv4",
|
||||
ipv4: &mockPacketConn{localAddr: ipv4Addr},
|
||||
ipv6: &mockPacketConn{localAddr: ipv6Addr},
|
||||
wantAddr: ipv4Addr,
|
||||
},
|
||||
{
|
||||
name: "IPv4 only",
|
||||
ipv4: &mockPacketConn{localAddr: ipv4Addr},
|
||||
ipv6: nil,
|
||||
wantAddr: ipv4Addr,
|
||||
},
|
||||
{
|
||||
name: "IPv6 only",
|
||||
ipv4: nil,
|
||||
ipv6: &mockPacketConn{localAddr: ipv6Addr},
|
||||
wantAddr: ipv6Addr,
|
||||
},
|
||||
{
|
||||
name: "neither returns nil",
|
||||
ipv4: nil,
|
||||
ipv6: nil,
|
||||
wantAddr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dualStack := NewDualStackPacketConn(tt.ipv4, tt.ipv6)
|
||||
assert.Equal(t, tt.wantAddr, dualStack.LocalAddr())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mock
|
||||
|
||||
type mockPacketConn struct {
|
||||
network string
|
||||
writeCount int
|
||||
readData []byte
|
||||
readAddr net.Addr
|
||||
localAddr net.Addr
|
||||
}
|
||||
|
||||
func (m *mockPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||
if m.readData != nil {
|
||||
return copy(b, m.readData), m.readAddr, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
func (m *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
m.writeCount++
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (m *mockPacketConn) Close() error { return nil }
|
||||
func (m *mockPacketConn) LocalAddr() net.Addr { return m.localAddr }
|
||||
func (m *mockPacketConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (m *mockPacketConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (m *mockPacketConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/pion/stun/v3"
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
@@ -28,22 +27,7 @@ type receiverCreator struct {
|
||||
}
|
||||
|
||||
func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
if ipv4PC, ok := pc.(*ipv4.PacketConn); ok {
|
||||
return rc.iceBind.createIPv4ReceiverFn(ipv4PC, conn, rxOffload, msgPool)
|
||||
}
|
||||
// IPv6 is currently not supported in the udpmux, this is a stub for compatibility with the
|
||||
// wireguard-go ReceiverCreator interface which is called for both IPv4 and IPv6.
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
buf := bufs[0]
|
||||
size, ep, err := conn.ReadFromUDPAddrPort(buf)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = size
|
||||
stdEp := &wgConn.StdNetEndpoint{AddrPort: ep}
|
||||
eps[0] = stdEp
|
||||
return 1, nil
|
||||
}
|
||||
return rc.iceBind.createReceiverFn(pc, conn, rxOffload, msgPool)
|
||||
}
|
||||
|
||||
// ICEBind is a bind implementation with two main features:
|
||||
@@ -73,6 +57,8 @@ type ICEBind struct {
|
||||
|
||||
muUDPMux sync.Mutex
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
ipv4Conn *net.UDPConn
|
||||
ipv6Conn *net.UDPConn
|
||||
}
|
||||
|
||||
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||
@@ -118,6 +104,12 @@ func (s *ICEBind) Close() error {
|
||||
|
||||
close(s.closedChan)
|
||||
|
||||
s.muUDPMux.Lock()
|
||||
s.ipv4Conn = nil
|
||||
s.ipv6Conn = nil
|
||||
s.udpMux = nil
|
||||
s.muUDPMux.Unlock()
|
||||
|
||||
return s.StdNetBind.Close()
|
||||
}
|
||||
|
||||
@@ -175,19 +167,18 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
s.muUDPMux.Lock()
|
||||
defer s.muUDPMux.Unlock()
|
||||
|
||||
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
|
||||
udpmux.UniversalUDPMuxParams{
|
||||
UDPConn: nbnet.WrapPacketConn(conn),
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
WGAddress: s.address,
|
||||
MTU: s.mtu,
|
||||
},
|
||||
)
|
||||
// Detect IPv4 vs IPv6 from connection's local address
|
||||
if localAddr := conn.LocalAddr().(*net.UDPAddr); localAddr.IP.To4() != nil {
|
||||
s.ipv4Conn = conn
|
||||
} else {
|
||||
s.ipv6Conn = conn
|
||||
}
|
||||
s.createOrUpdateMux()
|
||||
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
msgs := getMessages(msgsPool)
|
||||
for i := range bufs {
|
||||
@@ -195,12 +186,13 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
||||
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
||||
}
|
||||
defer putMessages(msgs, msgsPool)
|
||||
|
||||
var numMsgs int
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
if rxOffload {
|
||||
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
|
||||
//nolint
|
||||
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
|
||||
//nolint:staticcheck
|
||||
_, err = pc.ReadBatch((*msgs)[readAt:], 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -222,12 +214,12 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
||||
}
|
||||
numMsgs = 1
|
||||
}
|
||||
|
||||
for i := 0; i < numMsgs; i++ {
|
||||
msg := &(*msgs)[i]
|
||||
|
||||
// todo: handle err
|
||||
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
|
||||
if ok {
|
||||
if ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr); ok {
|
||||
continue
|
||||
}
|
||||
sizes[i] = msg.N
|
||||
@@ -248,6 +240,38 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
||||
}
|
||||
}
|
||||
|
||||
// createOrUpdateMux creates or updates the UDP mux with the available connections.
|
||||
// Must be called with muUDPMux held.
|
||||
func (s *ICEBind) createOrUpdateMux() {
|
||||
var muxConn net.PacketConn
|
||||
|
||||
switch {
|
||||
case s.ipv4Conn != nil && s.ipv6Conn != nil:
|
||||
muxConn = NewDualStackPacketConn(
|
||||
nbnet.WrapPacketConn(s.ipv4Conn),
|
||||
nbnet.WrapPacketConn(s.ipv6Conn),
|
||||
)
|
||||
case s.ipv4Conn != nil:
|
||||
muxConn = nbnet.WrapPacketConn(s.ipv4Conn)
|
||||
case s.ipv6Conn != nil:
|
||||
muxConn = nbnet.WrapPacketConn(s.ipv6Conn)
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
// Don't close the old mux - it doesn't own the underlying connections.
|
||||
// The sockets are managed by WireGuard's StdNetBind, not by us.
|
||||
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
|
||||
udpmux.UniversalUDPMuxParams{
|
||||
UDPConn: muxConn,
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
WGAddress: s.address,
|
||||
MTU: s.mtu,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
|
||||
for i := range buffers {
|
||||
if !stun.IsMessage(buffers[i]) {
|
||||
@@ -260,9 +284,14 @@ func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr)
|
||||
return true, err
|
||||
}
|
||||
|
||||
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
|
||||
if muxErr != nil {
|
||||
log.Warnf("failed to handle STUN packet")
|
||||
s.muUDPMux.Lock()
|
||||
mux := s.udpMux
|
||||
s.muUDPMux.Unlock()
|
||||
|
||||
if mux != nil {
|
||||
if muxErr := mux.HandleSTUNMessage(msg, addr); muxErr != nil {
|
||||
log.Warnf("failed to handle STUN packet: %v", muxErr)
|
||||
}
|
||||
}
|
||||
|
||||
buffers[i] = []byte{}
|
||||
|
||||
324
client/iface/bind/ice_bind_test.go
Normal file
324
client/iface/bind/ice_bind_test.go
Normal file
@@ -0,0 +1,324 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
func TestICEBind_CreatesReceiverForBothIPv4AndIPv6(t *testing.T) {
|
||||
iceBind := setupICEBind(t)
|
||||
|
||||
ipv4Conn, ipv6Conn := createDualStackConns(t)
|
||||
defer ipv4Conn.Close()
|
||||
defer ipv6Conn.Close()
|
||||
|
||||
rc := receiverCreator{iceBind}
|
||||
pool := createMsgPool()
|
||||
|
||||
// Simulate wireguard-go calling CreateReceiverFn for IPv4
|
||||
ipv4RecvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, pool)
|
||||
require.NotNil(t, ipv4RecvFn)
|
||||
|
||||
iceBind.muUDPMux.Lock()
|
||||
assert.NotNil(t, iceBind.ipv4Conn, "should store IPv4 connection")
|
||||
assert.Nil(t, iceBind.ipv6Conn, "IPv6 not added yet")
|
||||
assert.NotNil(t, iceBind.udpMux, "mux should be created after first connection")
|
||||
iceBind.muUDPMux.Unlock()
|
||||
|
||||
// Simulate wireguard-go calling CreateReceiverFn for IPv6
|
||||
ipv6RecvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, pool)
|
||||
require.NotNil(t, ipv6RecvFn)
|
||||
|
||||
iceBind.muUDPMux.Lock()
|
||||
assert.NotNil(t, iceBind.ipv4Conn, "should still have IPv4 connection")
|
||||
assert.NotNil(t, iceBind.ipv6Conn, "should now have IPv6 connection")
|
||||
assert.NotNil(t, iceBind.udpMux, "mux should still exist")
|
||||
iceBind.muUDPMux.Unlock()
|
||||
|
||||
mux, err := iceBind.GetICEMux()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, mux)
|
||||
}
|
||||
|
||||
func TestICEBind_WorksWithIPv4Only(t *testing.T) {
|
||||
iceBind := setupICEBind(t)
|
||||
|
||||
ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer ipv4Conn.Close()
|
||||
|
||||
rc := receiverCreator{iceBind}
|
||||
recvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, createMsgPool())
|
||||
require.NotNil(t, recvFn)
|
||||
|
||||
iceBind.muUDPMux.Lock()
|
||||
assert.NotNil(t, iceBind.ipv4Conn)
|
||||
assert.Nil(t, iceBind.ipv6Conn)
|
||||
assert.NotNil(t, iceBind.udpMux)
|
||||
iceBind.muUDPMux.Unlock()
|
||||
|
||||
mux, err := iceBind.GetICEMux()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, mux)
|
||||
}
|
||||
|
||||
func TestICEBind_WorksWithIPv6Only(t *testing.T) {
|
||||
iceBind := setupICEBind(t)
|
||||
|
||||
ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Skipf("IPv6 not available: %v", err)
|
||||
}
|
||||
defer ipv6Conn.Close()
|
||||
|
||||
rc := receiverCreator{iceBind}
|
||||
recvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, createMsgPool())
|
||||
require.NotNil(t, recvFn)
|
||||
|
||||
iceBind.muUDPMux.Lock()
|
||||
assert.Nil(t, iceBind.ipv4Conn)
|
||||
assert.NotNil(t, iceBind.ipv6Conn)
|
||||
assert.NotNil(t, iceBind.udpMux)
|
||||
iceBind.muUDPMux.Unlock()
|
||||
|
||||
mux, err := iceBind.GetICEMux()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, mux)
|
||||
}
|
||||
|
||||
// TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously verifies that we can communicate
|
||||
// with peers on different address families through the same DualStackPacketConn.
|
||||
func TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously(t *testing.T) {
|
||||
// two "remote peers" listening on different address families
|
||||
ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0")
|
||||
defer ipv4Peer.Close()
|
||||
|
||||
ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0})
|
||||
if err != nil {
|
||||
t.Skipf("IPv6 not available: %v", err)
|
||||
}
|
||||
defer ipv6Peer.Close()
|
||||
|
||||
// our local dual-stack connection
|
||||
ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0")
|
||||
defer ipv4Local.Close()
|
||||
|
||||
ipv6Local := listenUDP(t, "udp6", "[::1]:0")
|
||||
defer ipv6Local.Close()
|
||||
|
||||
dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local)
|
||||
|
||||
// send to both peers
|
||||
_, err = dualStack.WriteTo([]byte("to-ipv4"), ipv4Peer.LocalAddr())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = dualStack.WriteTo([]byte("to-ipv6"), ipv6Peer.LocalAddr())
|
||||
require.NoError(t, err)
|
||||
|
||||
// verify IPv4 peer got its packet from the IPv4 socket
|
||||
buf := make([]byte, 100)
|
||||
_ = ipv4Peer.SetReadDeadline(time.Now().Add(time.Second))
|
||||
n, addr, err := ipv4Peer.ReadFrom(buf)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "to-ipv4", string(buf[:n]))
|
||||
assert.Equal(t, ipv4Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port)
|
||||
|
||||
// verify IPv6 peer got its packet from the IPv6 socket
|
||||
_ = ipv6Peer.SetReadDeadline(time.Now().Add(time.Second))
|
||||
n, addr, err = ipv6Peer.ReadFrom(buf)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "to-ipv6", string(buf[:n]))
|
||||
assert.Equal(t, ipv6Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port)
|
||||
}
|
||||
|
||||
// TestICEBind_HandlesConcurrentMixedTraffic sends packets concurrently to both IPv4
|
||||
// and IPv6 peers. Verifies no packets get misrouted (IPv4 peer only gets v4- packets,
|
||||
// IPv6 peer only gets v6- packets). Some packet loss is acceptable for UDP.
|
||||
func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
|
||||
ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0")
|
||||
defer ipv4Peer.Close()
|
||||
|
||||
ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0})
|
||||
if err != nil {
|
||||
t.Skipf("IPv6 not available: %v", err)
|
||||
}
|
||||
defer ipv6Peer.Close()
|
||||
|
||||
ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0")
|
||||
defer ipv4Local.Close()
|
||||
|
||||
ipv6Local := listenUDP(t, "udp6", "[::1]:0")
|
||||
defer ipv6Local.Close()
|
||||
|
||||
dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local)
|
||||
|
||||
const packetsPerFamily = 500
|
||||
|
||||
ipv4Received := make(chan string, packetsPerFamily)
|
||||
ipv6Received := make(chan string, packetsPerFamily)
|
||||
|
||||
startGate := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buf := make([]byte, 100)
|
||||
for i := 0; i < packetsPerFamily; i++ {
|
||||
n, _, err := ipv4Peer.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ipv4Received <- string(buf[:n])
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buf := make([]byte, 100)
|
||||
for i := 0; i < packetsPerFamily; i++ {
|
||||
n, _, err := ipv6Peer.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ipv6Received <- string(buf[:n])
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-startGate
|
||||
for i := 0; i < packetsPerFamily; i++ {
|
||||
_, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v4-%04d", i)), ipv4Peer.LocalAddr())
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-startGate
|
||||
for i := 0; i < packetsPerFamily; i++ {
|
||||
_, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v6-%04d", i)), ipv6Peer.LocalAddr())
|
||||
}
|
||||
}()
|
||||
|
||||
close(startGate)
|
||||
|
||||
time.AfterFunc(5*time.Second, func() {
|
||||
_ = ipv4Peer.SetReadDeadline(time.Now())
|
||||
_ = ipv6Peer.SetReadDeadline(time.Now())
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
close(ipv4Received)
|
||||
close(ipv6Received)
|
||||
|
||||
ipv4Count := 0
|
||||
for pkt := range ipv4Received {
|
||||
require.True(t, len(pkt) >= 3 && pkt[:3] == "v4-", "IPv4 peer got misrouted packet: %s", pkt)
|
||||
ipv4Count++
|
||||
}
|
||||
|
||||
ipv6Count := 0
|
||||
for pkt := range ipv6Received {
|
||||
require.True(t, len(pkt) >= 3 && pkt[:3] == "v6-", "IPv6 peer got misrouted packet: %s", pkt)
|
||||
ipv6Count++
|
||||
}
|
||||
|
||||
assert.Equal(t, packetsPerFamily, ipv4Count)
|
||||
assert.Equal(t, packetsPerFamily, ipv6Count)
|
||||
}
|
||||
|
||||
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
network string
|
||||
addr string
|
||||
wantIPv4 bool
|
||||
}{
|
||||
{"IPv4 any", "udp4", "0.0.0.0:0", true},
|
||||
{"IPv4 loopback", "udp4", "127.0.0.1:0", true},
|
||||
{"IPv6 any", "udp6", "[::]:0", false},
|
||||
{"IPv6 loopback", "udp6", "[::1]:0", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr, err := net.ResolveUDPAddr(tt.network, tt.addr)
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := net.ListenUDP(tt.network, addr)
|
||||
if err != nil {
|
||||
t.Skipf("%s not available: %v", tt.network, err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
localAddr := conn.LocalAddr().(*net.UDPAddr)
|
||||
isIPv4 := localAddr.IP.To4() != nil
|
||||
assert.Equal(t, tt.wantIPv4, isIPv4)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// helpers
|
||||
|
||||
func setupICEBind(t *testing.T) *ICEBind {
|
||||
t.Helper()
|
||||
transportNet, err := stdnet.NewNet()
|
||||
require.NoError(t, err)
|
||||
|
||||
address := wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/10"),
|
||||
}
|
||||
return NewICEBind(transportNet, nil, address, 1280)
|
||||
}
|
||||
|
||||
func createDualStackConns(t *testing.T) (*net.UDPConn, *net.UDPConn) {
|
||||
t.Helper()
|
||||
ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
require.NoError(t, err)
|
||||
|
||||
ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||
if err != nil {
|
||||
ipv4Conn.Close()
|
||||
t.Skipf("IPv6 not available: %v", err)
|
||||
}
|
||||
return ipv4Conn, ipv6Conn
|
||||
}
|
||||
|
||||
func createMsgPool() *sync.Pool {
|
||||
return &sync.Pool{
|
||||
New: func() any {
|
||||
msgs := make([]ipv6.Message, 1)
|
||||
for i := range msgs {
|
||||
msgs[i].Buffers = make(net.Buffers, 1)
|
||||
msgs[i].OOB = make([]byte, 0, 40)
|
||||
}
|
||||
return &msgs
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func listenUDP(t *testing.T, network, addr string) *net.UDPConn {
|
||||
t.Helper()
|
||||
udpAddr, err := net.ResolveUDPAddr(network, addr)
|
||||
require.NoError(t, err)
|
||||
conn, err := net.ListenUDP(network, udpAddr)
|
||||
require.NoError(t, err)
|
||||
return conn
|
||||
}
|
||||
@@ -3,8 +3,22 @@ package configurer
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// buildPresharedKeyConfig creates a wgtypes.Config for setting a preshared key on a peer.
|
||||
// This is a shared helper used by both kernel and userspace configurers.
|
||||
func buildPresharedKeyConfig(peerKey wgtypes.Key, psk wgtypes.Key, updateOnly bool) wgtypes.Config {
|
||||
return wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{{
|
||||
PublicKey: peerKey,
|
||||
PresharedKey: &psk,
|
||||
UpdateOnly: updateOnly,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
|
||||
ipNets := make([]net.IPNet, len(prefixes))
|
||||
for i, prefix := range prefixes {
|
||||
|
||||
@@ -15,8 +15,6 @@ import (
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
)
|
||||
|
||||
var zeroKey wgtypes.Key
|
||||
|
||||
type KernelConfigurer struct {
|
||||
deviceName string
|
||||
}
|
||||
@@ -48,6 +46,18 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetPresharedKey sets the preshared key for a peer.
|
||||
// If updateOnly is true, only updates the existing peer; if false, creates or updates.
|
||||
func (c *KernelConfigurer) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
|
||||
parsedPeerKey, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg := buildPresharedKeyConfig(parsedPeerKey, psk, updateOnly)
|
||||
return c.configure(cfg)
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
@@ -279,7 +289,7 @@ func (c *KernelConfigurer) FullStats() (*Stats, error) {
|
||||
TxBytes: p.TransmitBytes,
|
||||
RxBytes: p.ReceiveBytes,
|
||||
LastHandshake: p.LastHandshakeTime,
|
||||
PresharedKey: p.PresharedKey != zeroKey,
|
||||
PresharedKey: [32]byte(p.PresharedKey),
|
||||
}
|
||||
if p.Endpoint != nil {
|
||||
peer.Endpoint = *p.Endpoint
|
||||
|
||||
@@ -22,17 +22,16 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
privateKey = "private_key"
|
||||
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
|
||||
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
|
||||
ipcKeyTxBytes = "tx_bytes"
|
||||
ipcKeyRxBytes = "rx_bytes"
|
||||
allowedIP = "allowed_ip"
|
||||
endpoint = "endpoint"
|
||||
fwmark = "fwmark"
|
||||
listenPort = "listen_port"
|
||||
publicKey = "public_key"
|
||||
presharedKey = "preshared_key"
|
||||
privateKey = "private_key"
|
||||
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
|
||||
ipcKeyTxBytes = "tx_bytes"
|
||||
ipcKeyRxBytes = "rx_bytes"
|
||||
allowedIP = "allowed_ip"
|
||||
endpoint = "endpoint"
|
||||
fwmark = "fwmark"
|
||||
listenPort = "listen_port"
|
||||
publicKey = "public_key"
|
||||
presharedKey = "preshared_key"
|
||||
)
|
||||
|
||||
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
||||
@@ -72,6 +71,18 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
|
||||
return c.device.IpcSet(toWgUserspaceString(config))
|
||||
}
|
||||
|
||||
// SetPresharedKey sets the preshared key for a peer.
|
||||
// If updateOnly is true, only updates the existing peer; if false, creates or updates.
|
||||
func (c *WGUSPConfigurer) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
|
||||
parsedPeerKey, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg := buildPresharedKeyConfig(parsedPeerKey, psk, updateOnly)
|
||||
return c.device.IpcSet(toWgUserspaceString(cfg))
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
@@ -422,23 +433,19 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||
hexKey := hex.EncodeToString(p.PublicKey[:])
|
||||
sb.WriteString(fmt.Sprintf("public_key=%s\n", hexKey))
|
||||
|
||||
if p.Remove {
|
||||
sb.WriteString("remove=true\n")
|
||||
}
|
||||
|
||||
if p.UpdateOnly {
|
||||
sb.WriteString("update_only=true\n")
|
||||
}
|
||||
|
||||
if p.PresharedKey != nil {
|
||||
preSharedHexKey := hex.EncodeToString(p.PresharedKey[:])
|
||||
sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey))
|
||||
}
|
||||
|
||||
if p.Remove {
|
||||
sb.WriteString("remove=true")
|
||||
}
|
||||
|
||||
if p.ReplaceAllowedIPs {
|
||||
sb.WriteString("replace_allowed_ips=true\n")
|
||||
}
|
||||
|
||||
for _, aip := range p.AllowedIPs {
|
||||
sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String()))
|
||||
}
|
||||
|
||||
if p.Endpoint != nil {
|
||||
sb.WriteString(fmt.Sprintf("endpoint=%s\n", p.Endpoint.String()))
|
||||
}
|
||||
@@ -446,6 +453,14 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||
if p.PersistentKeepaliveInterval != nil {
|
||||
sb.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds())))
|
||||
}
|
||||
|
||||
if p.ReplaceAllowedIPs {
|
||||
sb.WriteString("replace_allowed_ips=true\n")
|
||||
}
|
||||
|
||||
for _, aip := range p.AllowedIPs {
|
||||
sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String()))
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
@@ -599,7 +614,9 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
|
||||
continue
|
||||
}
|
||||
if val != "" && val != "0000000000000000000000000000000000000000000000000000000000000000" {
|
||||
currentPeer.PresharedKey = true
|
||||
if pskKey, err := hexToWireguardKey(val); err == nil {
|
||||
currentPeer.PresharedKey = [32]byte(pskKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ type Peer struct {
|
||||
TxBytes int64
|
||||
RxBytes int64
|
||||
LastHandshake time.Time
|
||||
PresharedKey bool
|
||||
PresharedKey [32]byte
|
||||
}
|
||||
|
||||
type Stats struct {
|
||||
|
||||
@@ -17,6 +17,7 @@ type WGConfigurer interface {
|
||||
RemovePeer(peerKey string) error
|
||||
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
||||
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
||||
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
|
||||
Close()
|
||||
GetStats() (map[string]configurer.WGStats, error)
|
||||
FullStats() (*configurer.Stats, error)
|
||||
|
||||
@@ -297,6 +297,19 @@ func (w *WGIface) FullStats() (*configurer.Stats, error) {
|
||||
return w.configurer.FullStats()
|
||||
}
|
||||
|
||||
// SetPresharedKey sets or updates the preshared key for a peer.
|
||||
// If updateOnly is true, only updates existing peer; if false, creates or updates.
|
||||
func (w *WGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if w.configurer == nil {
|
||||
return ErrIfaceNotFound
|
||||
}
|
||||
|
||||
return w.configurer.SetPresharedKey(peerKey, psk, updateOnly)
|
||||
}
|
||||
|
||||
func (w *WGIface) waitUntilRemoved() error {
|
||||
maxWaitTime := 5 * time.Second
|
||||
timeout := time.NewTimer(maxWaitTime)
|
||||
|
||||
@@ -117,16 +117,29 @@ func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
p.wgCurrentUsed = addrToEndpoint(endpoint)
|
||||
ep, err := addrToEndpoint(endpoint)
|
||||
if err != nil {
|
||||
log.Errorf("failed to convert endpoint address: %v", err)
|
||||
} else {
|
||||
p.wgCurrentUsed = ep
|
||||
}
|
||||
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
|
||||
ip, _ := netip.AddrFromSlice(addr.IP.To4())
|
||||
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
|
||||
return &bind.Endpoint{AddrPort: addrPort}
|
||||
func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
|
||||
if addr == nil {
|
||||
return nil, errors.New("nil address")
|
||||
}
|
||||
|
||||
ip, ok := netip.AddrFromSlice(addr.IP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("convert %s to netip.Addr", addr)
|
||||
}
|
||||
|
||||
addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
|
||||
return &bind.Endpoint{AddrPort: addrPort}, nil
|
||||
}
|
||||
|
||||
func (p *ProxyBind) CloseConn() error {
|
||||
|
||||
@@ -94,7 +94,9 @@ func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
p.wgEndpointCurrentUsedAddr = endpoint
|
||||
if endpoint != nil && endpoint.IP != nil {
|
||||
p.wgEndpointCurrentUsedAddr = endpoint
|
||||
}
|
||||
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
|
||||
@@ -59,7 +59,6 @@ func NewConnectClient(
|
||||
config *profilemanager.Config,
|
||||
statusRecorder *peer.Status,
|
||||
doInitalAutoUpdate bool,
|
||||
|
||||
) *ConnectClient {
|
||||
return &ConnectClient{
|
||||
ctx: ctx,
|
||||
@@ -71,8 +70,8 @@ func NewConnectClient(
|
||||
}
|
||||
|
||||
// Run with main logic.
|
||||
func (c *ConnectClient) Run(runningChan chan struct{}) error {
|
||||
return c.run(MobileDependency{}, runningChan)
|
||||
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
|
||||
return c.run(MobileDependency{}, runningChan, logPath)
|
||||
}
|
||||
|
||||
// RunOnAndroid with main logic on mobile system
|
||||
@@ -93,7 +92,7 @@ func (c *ConnectClient) RunOnAndroid(
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
StateFilePath: stateFilePath,
|
||||
}
|
||||
return c.run(mobileDependency, nil)
|
||||
return c.run(mobileDependency, nil, "")
|
||||
}
|
||||
|
||||
func (c *ConnectClient) RunOniOS(
|
||||
@@ -111,10 +110,10 @@ func (c *ConnectClient) RunOniOS(
|
||||
DnsManager: dnsManager,
|
||||
StateFilePath: stateFilePath,
|
||||
}
|
||||
return c.run(mobileDependency, nil)
|
||||
return c.run(mobileDependency, nil, "")
|
||||
}
|
||||
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error {
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
rec := c.statusRecorder
|
||||
@@ -284,7 +283,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
relayURLs, token := parseRelayInfo(loginResp)
|
||||
peerConfig := loginResp.GetPeerConfig()
|
||||
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
@@ -472,7 +471,7 @@ func (c *ConnectClient) SetSyncResponsePersistence(enabled bool) {
|
||||
}
|
||||
|
||||
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
||||
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
|
||||
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig, logPath string) (*EngineConfig, error) {
|
||||
nm := false
|
||||
if config.NetworkMonitor != nil {
|
||||
nm = *config.NetworkMonitor
|
||||
@@ -507,7 +506,10 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
||||
|
||||
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
||||
|
||||
MTU: selectMTU(config.MTU, peerConfig.Mtu),
|
||||
MTU: selectMTU(config.MTU, peerConfig.Mtu),
|
||||
LogPath: logPath,
|
||||
|
||||
ProfileConfig: config,
|
||||
}
|
||||
|
||||
if config.PreSharedKey != "" {
|
||||
|
||||
@@ -28,8 +28,10 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const readmeContent = `Netbird debug bundle
|
||||
@@ -57,6 +59,7 @@ block.prof: Block profiling information.
|
||||
heap.prof: Heap profiling information (snapshot of memory allocations).
|
||||
allocs.prof: Allocations profiling information.
|
||||
threadcreate.prof: Thread creation profiling information.
|
||||
cpu.prof: CPU profiling information.
|
||||
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
|
||||
|
||||
|
||||
@@ -223,10 +226,11 @@ type BundleGenerator struct {
|
||||
internalConfig *profilemanager.Config
|
||||
statusRecorder *peer.Status
|
||||
syncResponse *mgmProto.SyncResponse
|
||||
logFile string
|
||||
logPath string
|
||||
cpuProfile []byte
|
||||
refreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
|
||||
anonymize bool
|
||||
clientStatus string
|
||||
includeSystemInfo bool
|
||||
logFileCount uint32
|
||||
|
||||
@@ -235,7 +239,6 @@ type BundleGenerator struct {
|
||||
|
||||
type BundleConfig struct {
|
||||
Anonymize bool
|
||||
ClientStatus string
|
||||
IncludeSystemInfo bool
|
||||
LogFileCount uint32
|
||||
}
|
||||
@@ -244,7 +247,9 @@ type GeneratorDependencies struct {
|
||||
InternalConfig *profilemanager.Config
|
||||
StatusRecorder *peer.Status
|
||||
SyncResponse *mgmProto.SyncResponse
|
||||
LogFile string
|
||||
LogPath string
|
||||
CPUProfile []byte
|
||||
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
}
|
||||
|
||||
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
|
||||
@@ -260,10 +265,11 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
||||
internalConfig: deps.InternalConfig,
|
||||
statusRecorder: deps.StatusRecorder,
|
||||
syncResponse: deps.SyncResponse,
|
||||
logFile: deps.LogFile,
|
||||
logPath: deps.LogPath,
|
||||
cpuProfile: deps.CPUProfile,
|
||||
refreshStatus: deps.RefreshStatus,
|
||||
|
||||
anonymize: cfg.Anonymize,
|
||||
clientStatus: cfg.ClientStatus,
|
||||
includeSystemInfo: cfg.IncludeSystemInfo,
|
||||
logFileCount: logFileCount,
|
||||
}
|
||||
@@ -309,13 +315,6 @@ func (g *BundleGenerator) createArchive() error {
|
||||
return fmt.Errorf("add status: %w", err)
|
||||
}
|
||||
|
||||
if g.statusRecorder != nil {
|
||||
status := g.statusRecorder.GetFullStatus()
|
||||
seedFromStatus(g.anonymizer, &status)
|
||||
} else {
|
||||
log.Debugf("no status recorder available for seeding")
|
||||
}
|
||||
|
||||
if err := g.addConfig(); err != nil {
|
||||
log.Errorf("failed to add config to debug bundle: %v", err)
|
||||
}
|
||||
@@ -332,6 +331,10 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("failed to add profiles to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addCPUProfile(); err != nil {
|
||||
log.Errorf("failed to add CPU profile to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addStackTrace(); err != nil {
|
||||
log.Errorf("failed to add stack trace to debug bundle: %v", err)
|
||||
}
|
||||
@@ -352,7 +355,7 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("failed to add wg show output: %v", err)
|
||||
}
|
||||
|
||||
if g.logFile != "" && !slices.Contains(util.SpecialLogs, g.logFile) {
|
||||
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
|
||||
if err := g.addLogfile(); err != nil {
|
||||
log.Errorf("failed to add log file to debug bundle: %v", err)
|
||||
if err := g.trySystemdLogFallback(); err != nil {
|
||||
@@ -401,11 +404,30 @@ func (g *BundleGenerator) addReadme() error {
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addStatus() error {
|
||||
if status := g.clientStatus; status != "" {
|
||||
statusReader := strings.NewReader(status)
|
||||
if g.statusRecorder != nil {
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
if g.refreshStatus != nil {
|
||||
g.refreshStatus()
|
||||
}
|
||||
|
||||
fullStatus := g.statusRecorder.GetFullStatus()
|
||||
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
|
||||
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, g.anonymize, version.NetbirdVersion(), "", nil, nil, nil, "", profName)
|
||||
statusOutput := overview.FullDetailSummary()
|
||||
|
||||
statusReader := strings.NewReader(statusOutput)
|
||||
if err := g.addFileToZip(statusReader, "status.txt"); err != nil {
|
||||
return fmt.Errorf("add status file to zip: %w", err)
|
||||
}
|
||||
seedFromStatus(g.anonymizer, &fullStatus)
|
||||
} else {
|
||||
log.Debugf("no status recorder available for seeding")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -535,6 +557,19 @@ func (g *BundleGenerator) addProf() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addCPUProfile() error {
|
||||
if len(g.cpuProfile) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
reader := bytes.NewReader(g.cpuProfile)
|
||||
if err := g.addFileToZip(reader, "cpu.prof"); err != nil {
|
||||
return fmt.Errorf("add CPU profile to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addStackTrace() error {
|
||||
buf := make([]byte, 5242880) // 5 MB buffer
|
||||
n := runtime.Stack(buf, true)
|
||||
@@ -710,14 +745,14 @@ func (g *BundleGenerator) addCorruptedStateFiles() error {
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addLogfile() error {
|
||||
if g.logFile == "" {
|
||||
if g.logPath == "" {
|
||||
log.Debugf("skipping empty log file in debug bundle")
|
||||
return nil
|
||||
}
|
||||
|
||||
logDir := filepath.Dir(g.logFile)
|
||||
logDir := filepath.Dir(g.logPath)
|
||||
|
||||
if err := g.addSingleLogfile(g.logFile, clientLogFile); err != nil {
|
||||
if err := g.addSingleLogfile(g.logPath, clientLogFile); err != nil {
|
||||
return fmt.Errorf("add client log file to zip: %w", err)
|
||||
}
|
||||
|
||||
|
||||
101
client/internal/debug/upload.go
Normal file
101
client/internal/debug/upload.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package debug
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
const maxBundleUploadSize = 50 * 1024 * 1024
|
||||
|
||||
func UploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) {
|
||||
response, err := getUploadURL(ctx, url, managementURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = upload(ctx, filePath, response)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return response.Key, nil
|
||||
}
|
||||
|
||||
func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error {
|
||||
fileData, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open file: %w", err)
|
||||
}
|
||||
|
||||
defer fileData.Close()
|
||||
|
||||
stat, err := fileData.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat file: %w", err)
|
||||
}
|
||||
|
||||
if stat.Size() > maxBundleUploadSize {
|
||||
return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create PUT request: %w", err)
|
||||
}
|
||||
|
||||
req.ContentLength = stat.Size()
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
|
||||
putResp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("upload failed: %v", err)
|
||||
}
|
||||
defer putResp.Body.Close()
|
||||
|
||||
if putResp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(putResp.Body)
|
||||
return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) {
|
||||
id := getURLHash(managementURL)
|
||||
getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create GET request: %w", err)
|
||||
}
|
||||
|
||||
getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue)
|
||||
|
||||
resp, err := http.DefaultClient.Do(getReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get presigned URL: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
urlBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response body: %w", err)
|
||||
}
|
||||
var response types.GetURLResponse
|
||||
if err := json.Unmarshal(urlBytes, &response); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response: %w", err)
|
||||
}
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
func getURLHash(url string) string {
|
||||
return fmt.Sprintf("%x", sha256.Sum256([]byte(url)))
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package debug
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -38,7 +38,7 @@ func TestUpload(t *testing.T) {
|
||||
fileContent := []byte("test file content")
|
||||
err := os.WriteFile(file, fileContent, 0640)
|
||||
require.NoError(t, err)
|
||||
key, err := uploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
|
||||
key, err := UploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
|
||||
require.NoError(t, err)
|
||||
id := getURLHash(testURL)
|
||||
require.Contains(t, key, id+"/")
|
||||
@@ -60,7 +60,7 @@ func (g *BundleGenerator) toWGShowFormat(s *configurer.Stats) string {
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" latest handshake: %s\n", peer.LastHandshake.Format(time.RFC1123)))
|
||||
sb.WriteString(fmt.Sprintf(" transfer: %d B received, %d B sent\n", peer.RxBytes, peer.TxBytes))
|
||||
if peer.PresharedKey {
|
||||
if peer.PresharedKey != [32]byte{} {
|
||||
sb.WriteString(" preshared key: (hidden)\n")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,7 +81,10 @@ func (d *Resolver) ProbeAvailability() {}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
logger := log.WithField("request_id", resutil.GetRequestID(w))
|
||||
logger := log.WithFields(log.Fields{
|
||||
"request_id": resutil.GetRequestID(w),
|
||||
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||
})
|
||||
|
||||
if len(r.Question) == 0 {
|
||||
logger.Debug("received local resolver request with no question")
|
||||
@@ -120,7 +123,7 @@ func (d *Resolver) determineRcode(question dns.Question, result lookupResult) in
|
||||
}
|
||||
|
||||
// No records found, but domain exists with different record types (NODATA)
|
||||
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
|
||||
if d.hasRecordsForDomain(domain.Domain(question.Name), question.Qtype) {
|
||||
return dns.RcodeSuccess
|
||||
}
|
||||
|
||||
@@ -164,11 +167,15 @@ func (d *Resolver) continueToNext(logger *log.Entry, w dns.ResponseWriter, r *dn
|
||||
}
|
||||
|
||||
// hasRecordsForDomain checks if any records exist for the given domain name regardless of type
|
||||
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
|
||||
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain, qType uint16) bool {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
_, exists := d.domains[domainName]
|
||||
if !exists && supportsWildcard(qType) {
|
||||
testWild := transformDomainToWildcard(string(domainName))
|
||||
_, exists = d.domains[domain.Domain(testWild)]
|
||||
}
|
||||
return exists
|
||||
}
|
||||
|
||||
@@ -195,6 +202,16 @@ type lookupResult struct {
|
||||
func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) lookupResult {
|
||||
d.mu.RLock()
|
||||
records, found := d.records[question]
|
||||
usingWildcard := false
|
||||
wildQuestion := transformToWildcard(question)
|
||||
// RFC 4592 section 2.2.1: wildcard only matches if the name does NOT exist in the zone.
|
||||
// If the domain exists with any record type, return NODATA instead of wildcard match.
|
||||
if !found && supportsWildcard(question.Qtype) {
|
||||
if _, domainExists := d.domains[domain.Domain(question.Name)]; !domainExists {
|
||||
records, found = d.records[wildQuestion]
|
||||
usingWildcard = found
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
d.mu.RUnlock()
|
||||
@@ -216,18 +233,53 @@ func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) looku
|
||||
// if there's more than one record, rotate them (round-robin)
|
||||
if len(recordsCopy) > 1 {
|
||||
d.mu.Lock()
|
||||
records = d.records[question]
|
||||
q := question
|
||||
if usingWildcard {
|
||||
q = wildQuestion
|
||||
}
|
||||
records = d.records[q]
|
||||
if len(records) > 1 {
|
||||
first := records[0]
|
||||
records = append(records[1:], first)
|
||||
d.records[question] = records
|
||||
d.records[q] = records
|
||||
}
|
||||
d.mu.Unlock()
|
||||
}
|
||||
|
||||
if usingWildcard {
|
||||
return responseFromWildRecords(question.Name, wildQuestion.Name, recordsCopy)
|
||||
}
|
||||
|
||||
return lookupResult{records: recordsCopy, rcode: dns.RcodeSuccess}
|
||||
}
|
||||
|
||||
func transformToWildcard(question dns.Question) dns.Question {
|
||||
wildQuestion := question
|
||||
wildQuestion.Name = transformDomainToWildcard(wildQuestion.Name)
|
||||
return wildQuestion
|
||||
}
|
||||
|
||||
func transformDomainToWildcard(domain string) string {
|
||||
s := strings.Split(domain, ".")
|
||||
s[0] = "*"
|
||||
return strings.Join(s, ".")
|
||||
}
|
||||
|
||||
func supportsWildcard(queryType uint16) bool {
|
||||
return queryType != dns.TypeNS && queryType != dns.TypeSOA
|
||||
}
|
||||
|
||||
func responseFromWildRecords(originalName, wildName string, wildRecords []dns.RR) lookupResult {
|
||||
records := make([]dns.RR, len(wildRecords))
|
||||
for i, record := range wildRecords {
|
||||
copiedRecord := dns.Copy(record)
|
||||
copiedRecord.Header().Name = originalName
|
||||
records[i] = copiedRecord
|
||||
}
|
||||
|
||||
return lookupResult{records: records, rcode: dns.RcodeSuccess}
|
||||
}
|
||||
|
||||
// lookupCNAMEChain follows a CNAME chain and returns the CNAME records along with
|
||||
// the final resolved record of the requested type. This is required for musl libc
|
||||
// compatibility, which expects the full answer chain rather than just the CNAME.
|
||||
@@ -237,6 +289,13 @@ func (d *Resolver) lookupCNAMEChain(logger *log.Entry, cnameQuestion dns.Questio
|
||||
|
||||
for range maxDepth {
|
||||
cnameRecords := d.getRecords(cnameQuestion)
|
||||
if len(cnameRecords) == 0 && supportsWildcard(targetType) {
|
||||
wildQuestion := transformToWildcard(cnameQuestion)
|
||||
if wildRecords := d.getRecords(wildQuestion); len(wildRecords) > 0 {
|
||||
cnameRecords = responseFromWildRecords(cnameQuestion.Name, wildQuestion.Name, wildRecords).records
|
||||
}
|
||||
}
|
||||
|
||||
if len(cnameRecords) == 0 {
|
||||
break
|
||||
}
|
||||
@@ -303,7 +362,7 @@ func (d *Resolver) resolveCNAMETarget(logger *log.Entry, targetName string, targ
|
||||
}
|
||||
|
||||
// domain exists locally but not this record type (NODATA)
|
||||
if d.hasRecordsForDomain(domain.Domain(targetName)) {
|
||||
if d.hasRecordsForDomain(domain.Domain(targetName), targetType) {
|
||||
return lookupResult{rcode: dns.RcodeSuccess}
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -71,6 +71,11 @@ type upstreamResolverBase struct {
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
type upstreamFailure struct {
|
||||
upstream netip.AddrPort
|
||||
reason string
|
||||
}
|
||||
|
||||
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
@@ -114,7 +119,10 @@ func (u *upstreamResolverBase) Stop() {
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
logger := log.WithField("request_id", resutil.GetRequestID(w))
|
||||
logger := log.WithFields(log.Fields{
|
||||
"request_id": resutil.GetRequestID(w),
|
||||
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||
})
|
||||
|
||||
u.prepareRequest(r)
|
||||
|
||||
@@ -123,11 +131,13 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
if u.tryUpstreamServers(w, r, logger) {
|
||||
return
|
||||
ok, failures := u.tryUpstreamServers(w, r, logger)
|
||||
if len(failures) > 0 {
|
||||
u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger)
|
||||
}
|
||||
if !ok {
|
||||
u.writeErrorResponse(w, r, logger)
|
||||
}
|
||||
|
||||
u.writeErrorResponse(w, r, logger)
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
|
||||
@@ -136,7 +146,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool {
|
||||
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
timeout := u.upstreamTimeout
|
||||
if len(u.upstreamServers) > 1 {
|
||||
maxTotal := 5 * time.Second
|
||||
@@ -149,15 +159,19 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
|
||||
}
|
||||
}
|
||||
|
||||
var failures []upstreamFailure
|
||||
for _, upstream := range u.upstreamServers {
|
||||
if u.queryUpstream(w, r, upstream, timeout, logger) {
|
||||
return true
|
||||
if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil {
|
||||
failures = append(failures, *failure)
|
||||
} else {
|
||||
return true, failures
|
||||
}
|
||||
}
|
||||
return false
|
||||
return false, failures
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) bool {
|
||||
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
|
||||
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
|
||||
var rm *dns.Msg
|
||||
var t time.Duration
|
||||
var err error
|
||||
@@ -171,31 +185,32 @@ func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, u
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger)
|
||||
return false
|
||||
return u.handleUpstreamError(err, upstream, startTime)
|
||||
}
|
||||
|
||||
if rm == nil || !rm.Response {
|
||||
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
|
||||
return false
|
||||
return &upstreamFailure{upstream: upstream, reason: "no response"}
|
||||
}
|
||||
|
||||
return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
|
||||
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
|
||||
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
|
||||
}
|
||||
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) {
|
||||
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
|
||||
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err)
|
||||
return
|
||||
return &upstreamFailure{upstream: upstream, reason: err.Error()}
|
||||
}
|
||||
|
||||
elapsed := time.Since(startTime)
|
||||
timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout)
|
||||
reason := fmt.Sprintf("timeout after %v", elapsed.Truncate(time.Millisecond))
|
||||
if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" {
|
||||
timeoutMsg += " " + peerInfo
|
||||
reason += " " + peerInfo
|
||||
}
|
||||
timeoutMsg += fmt.Sprintf(" - error: %v", err)
|
||||
logger.Warn(timeoutMsg)
|
||||
return &upstreamFailure{upstream: upstream, reason: reason}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
||||
@@ -215,16 +230,34 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn
|
||||
return true
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
|
||||
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
|
||||
func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) {
|
||||
totalUpstreams := len(u.upstreamServers)
|
||||
failedCount := len(failures)
|
||||
failureSummary := formatFailures(failures)
|
||||
|
||||
if succeeded {
|
||||
logger.Warnf("%d/%d upstreams failed for domain=%s: %s", failedCount, totalUpstreams, domain, failureSummary)
|
||||
} else {
|
||||
logger.Errorf("%d/%d upstreams failed for domain=%s: %s", failedCount, totalUpstreams, domain, failureSummary)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(r, dns.RcodeServerFailure)
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
|
||||
logger.Errorf("write error response for domain=%s: %s", r.Question[0].Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
func formatFailures(failures []upstreamFailure) string {
|
||||
parts := make([]string, 0, len(failures))
|
||||
for _, f := range failures {
|
||||
parts = append(parts, fmt.Sprintf("%s=%s", f.upstream, f.reason))
|
||||
}
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
|
||||
// ProbeAvailability tests all upstream servers simultaneously and
|
||||
// disables the resolver if none work
|
||||
func (u *upstreamResolverBase) ProbeAvailability() {
|
||||
@@ -468,7 +501,6 @@ func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upst
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
|
||||
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
|
||||
func FormatPeerStatus(peerState *peer.State) string {
|
||||
isConnected := peerState.ConnStatus == peer.StatusConnected
|
||||
|
||||
@@ -2,6 +2,7 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
@@ -9,6 +10,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
@@ -140,6 +143,23 @@ func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg)
|
||||
return c.r, c.rtt, c.err
|
||||
}
|
||||
|
||||
type mockUpstreamResponse struct {
|
||||
msg *dns.Msg
|
||||
err error
|
||||
}
|
||||
|
||||
type mockUpstreamResolverPerServer struct {
|
||||
responses map[string]mockUpstreamResponse
|
||||
rtt time.Duration
|
||||
}
|
||||
|
||||
func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
|
||||
if r, ok := c.responses[upstream]; ok {
|
||||
return r.msg, c.rtt, r.err
|
||||
}
|
||||
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
mockClient := &mockUpstreamResolver{
|
||||
err: dns.ErrTime,
|
||||
@@ -191,3 +211,267 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
t.Errorf("should be enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_Failover(t *testing.T) {
|
||||
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
|
||||
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
|
||||
|
||||
successAnswer := "192.0.2.100"
|
||||
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
upstream1 mockUpstreamResponse
|
||||
upstream2 mockUpstreamResponse
|
||||
expectedRcode int
|
||||
expectAnswer bool
|
||||
expectTrySecond bool
|
||||
}{
|
||||
{
|
||||
name: "success on first upstream",
|
||||
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
expectAnswer: true,
|
||||
expectTrySecond: false,
|
||||
},
|
||||
{
|
||||
name: "SERVFAIL from first should try second",
|
||||
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
expectAnswer: true,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "REFUSED from first should try second",
|
||||
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
expectAnswer: true,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "NXDOMAIN from first should NOT try second",
|
||||
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeNameError, "")},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
expectedRcode: dns.RcodeNameError,
|
||||
expectAnswer: false,
|
||||
expectTrySecond: false,
|
||||
},
|
||||
{
|
||||
name: "timeout from first should try second",
|
||||
upstream1: mockUpstreamResponse{err: timeoutErr},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
expectAnswer: true,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "no response from first should try second",
|
||||
upstream1: mockUpstreamResponse{msg: nil},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
expectAnswer: true,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "both upstreams return SERVFAIL",
|
||||
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectAnswer: false,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "both upstreams timeout",
|
||||
upstream1: mockUpstreamResponse{err: timeoutErr},
|
||||
upstream2: mockUpstreamResponse{err: timeoutErr},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectAnswer: false,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "first SERVFAIL then timeout",
|
||||
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
upstream2: mockUpstreamResponse{err: timeoutErr},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectAnswer: false,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "first timeout then SERVFAIL",
|
||||
upstream1: mockUpstreamResponse{err: timeoutErr},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectAnswer: false,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "first REFUSED then SERVFAIL",
|
||||
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectAnswer: false,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var queriedUpstreams []string
|
||||
mockClient := &mockUpstreamResolverPerServer{
|
||||
responses: map[string]mockUpstreamResponse{
|
||||
upstream1.String(): tc.upstream1,
|
||||
upstream2.String(): tc.upstream2,
|
||||
},
|
||||
rtt: time.Millisecond,
|
||||
}
|
||||
|
||||
trackingClient := &trackingMockClient{
|
||||
inner: mockClient,
|
||||
queriedUpstreams: &queriedUpstreams,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: trackingClient,
|
||||
upstreamServers: []netip.AddrPort{upstream1, upstream2},
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
resolver.ServeDNS(responseWriter, inputMSG)
|
||||
|
||||
require.NotNil(t, responseMSG, "should write a response")
|
||||
assert.Equal(t, tc.expectedRcode, responseMSG.Rcode, "unexpected rcode")
|
||||
|
||||
if tc.expectAnswer {
|
||||
require.NotEmpty(t, responseMSG.Answer, "expected answer records")
|
||||
assert.Contains(t, responseMSG.Answer[0].String(), successAnswer)
|
||||
}
|
||||
|
||||
if tc.expectTrySecond {
|
||||
assert.Len(t, queriedUpstreams, 2, "should have tried both upstreams")
|
||||
assert.Equal(t, upstream1.String(), queriedUpstreams[0])
|
||||
assert.Equal(t, upstream2.String(), queriedUpstreams[1])
|
||||
} else {
|
||||
assert.Len(t, queriedUpstreams, 1, "should have only tried first upstream")
|
||||
assert.Equal(t, upstream1.String(), queriedUpstreams[0])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type trackingMockClient struct {
|
||||
inner *mockUpstreamResolverPerServer
|
||||
queriedUpstreams *[]string
|
||||
}
|
||||
|
||||
func (t *trackingMockClient) exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) {
|
||||
*t.queriedUpstreams = append(*t.queriedUpstreams, upstream)
|
||||
return t.inner.exchange(ctx, upstream, r)
|
||||
}
|
||||
|
||||
func buildMockResponse(rcode int, answer string) *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.Response = true
|
||||
m.Rcode = rcode
|
||||
|
||||
if rcode == dns.RcodeSuccess && answer != "" {
|
||||
m.Answer = []dns.RR{
|
||||
&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "example.com.",
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300,
|
||||
},
|
||||
A: net.ParseIP(answer),
|
||||
},
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
|
||||
upstream := netip.MustParseAddrPort("192.0.2.1:53")
|
||||
|
||||
mockClient := &mockUpstreamResolverPerServer{
|
||||
responses: map[string]mockUpstreamResponse{
|
||||
upstream.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
},
|
||||
rtt: time.Millisecond,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: mockClient,
|
||||
upstreamServers: []netip.AddrPort{upstream},
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
resolver.ServeDNS(responseWriter, inputMSG)
|
||||
|
||||
require.NotNil(t, responseMSG, "should write a response")
|
||||
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL")
|
||||
}
|
||||
|
||||
func TestFormatFailures(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
failures []upstreamFailure
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty slice",
|
||||
failures: []upstreamFailure{},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "single failure",
|
||||
failures: []upstreamFailure{
|
||||
{upstream: netip.MustParseAddrPort("8.8.8.8:53"), reason: "SERVFAIL"},
|
||||
},
|
||||
expected: "8.8.8.8:53=SERVFAIL",
|
||||
},
|
||||
{
|
||||
name: "multiple failures",
|
||||
failures: []upstreamFailure{
|
||||
{upstream: netip.MustParseAddrPort("8.8.8.8:53"), reason: "SERVFAIL"},
|
||||
{upstream: netip.MustParseAddrPort("8.8.4.4:53"), reason: "timeout after 2s"},
|
||||
},
|
||||
expected: "8.8.8.8:53=SERVFAIL, 8.8.4.4:53=timeout after 2s",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := formatFailures(tc.failures)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
@@ -42,12 +43,14 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
||||
"github.com/netbirdio/netbird/client/jobexec"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
@@ -132,6 +135,11 @@ type EngineConfig struct {
|
||||
LazyConnectionEnabled bool
|
||||
|
||||
MTU uint16
|
||||
|
||||
// for debug bundle generation
|
||||
ProfileConfig *profilemanager.Config
|
||||
|
||||
LogPath string
|
||||
}
|
||||
|
||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||
@@ -195,7 +203,8 @@ type Engine struct {
|
||||
stateManager *statemanager.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
|
||||
// Sync response persistence
|
||||
// Sync response persistence (protected by syncRespMux)
|
||||
syncRespMux sync.RWMutex
|
||||
persistSyncResponse bool
|
||||
latestSyncResponse *mgmProto.SyncResponse
|
||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||
@@ -211,6 +220,9 @@ type Engine struct {
|
||||
shutdownWg sync.WaitGroup
|
||||
|
||||
probeStunTurn *relay.StunTurnProbe
|
||||
|
||||
jobExecutor *jobexec.Executor
|
||||
jobExecutorWG sync.WaitGroup
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -224,7 +236,18 @@ type localIpUpdater interface {
|
||||
}
|
||||
|
||||
// NewEngine creates a new Connection Engine with probes attached
|
||||
func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, stateManager *statemanager.Manager) *Engine {
|
||||
func NewEngine(
|
||||
clientCtx context.Context,
|
||||
clientCancel context.CancelFunc,
|
||||
signalClient signal.Client,
|
||||
mgmClient mgm.Client,
|
||||
relayManager *relayClient.Manager,
|
||||
config *EngineConfig,
|
||||
mobileDep MobileDependency,
|
||||
statusRecorder *peer.Status,
|
||||
checks []*mgmProto.Checks,
|
||||
stateManager *statemanager.Manager,
|
||||
) *Engine {
|
||||
engine := &Engine{
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
@@ -244,6 +267,7 @@ func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signa
|
||||
checks: checks,
|
||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||
jobExecutor: jobexec.NewExecutor(),
|
||||
}
|
||||
|
||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||
@@ -312,6 +336,8 @@ func (e *Engine) Stop() error {
|
||||
e.cancel()
|
||||
}
|
||||
|
||||
e.jobExecutorWG.Wait() // block until job goroutines finish
|
||||
|
||||
e.close()
|
||||
|
||||
// stop flow manager after wg interface is gone
|
||||
@@ -479,6 +505,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return fmt.Errorf("up wg interface: %w", err)
|
||||
}
|
||||
|
||||
// Set the WireGuard interface for rosenpass after interface is up
|
||||
if e.rpManager != nil {
|
||||
e.rpManager.SetInterface(e.wgInterface)
|
||||
}
|
||||
|
||||
// if inbound conns are blocked there is no need to create the ACL manager
|
||||
if e.firewall != nil && !e.config.BlockInbound {
|
||||
e.acl = acl.NewDefaultManager(e.firewall)
|
||||
@@ -500,6 +531,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
e.receiveSignalEvents()
|
||||
e.receiveManagementEvents()
|
||||
e.receiveJobEvents()
|
||||
|
||||
// starting network monitor at the very last to avoid disruptions
|
||||
e.startNetworkMonitor()
|
||||
@@ -828,9 +860,18 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
|
||||
// Read the storage-enabled flag under the syncRespMux too.
|
||||
e.syncRespMux.RLock()
|
||||
enabled := e.persistSyncResponse
|
||||
e.syncRespMux.RUnlock()
|
||||
|
||||
// Store sync response if persistence is enabled
|
||||
if e.persistSyncResponse {
|
||||
if enabled {
|
||||
e.syncRespMux.Lock()
|
||||
e.latestSyncResponse = update
|
||||
e.syncRespMux.Unlock()
|
||||
|
||||
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
|
||||
}
|
||||
|
||||
@@ -960,6 +1001,80 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
func (e *Engine) receiveJobEvents() {
|
||||
e.jobExecutorWG.Add(1)
|
||||
go func() {
|
||||
defer e.jobExecutorWG.Done()
|
||||
err := e.mgmClient.Job(e.ctx, func(msg *mgmProto.JobRequest) *mgmProto.JobResponse {
|
||||
resp := mgmProto.JobResponse{
|
||||
ID: msg.ID,
|
||||
Status: mgmProto.JobStatus_failed,
|
||||
}
|
||||
switch params := msg.WorkloadParameters.(type) {
|
||||
case *mgmProto.JobRequest_Bundle:
|
||||
bundleResult, err := e.handleBundle(params.Bundle)
|
||||
if err != nil {
|
||||
log.Errorf("handling bundle: %v", err)
|
||||
resp.Reason = []byte(err.Error())
|
||||
return &resp
|
||||
}
|
||||
resp.Status = mgmProto.JobStatus_succeeded
|
||||
resp.WorkloadResults = bundleResult
|
||||
return &resp
|
||||
default:
|
||||
resp.Reason = []byte(jobexec.ErrJobNotImplemented.Error())
|
||||
return &resp
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
// happens if management is unavailable for a long time.
|
||||
// We want to cancel the operation of the whole client
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
e.clientCancel()
|
||||
return
|
||||
}
|
||||
log.Info("stopped receiving jobs from Management Service")
|
||||
}()
|
||||
log.Info("connecting to Management Service jobs stream")
|
||||
}
|
||||
|
||||
func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobResponse_Bundle, error) {
|
||||
log.Infof("handle remote debug bundle request: %s", params.String())
|
||||
syncResponse, err := e.GetLatestSyncResponse()
|
||||
if err != nil {
|
||||
log.Warnf("get latest sync response: %v", err)
|
||||
}
|
||||
|
||||
bundleDeps := debug.GeneratorDependencies{
|
||||
InternalConfig: e.config.ProfileConfig,
|
||||
StatusRecorder: e.statusRecorder,
|
||||
SyncResponse: syncResponse,
|
||||
LogPath: e.config.LogPath,
|
||||
RefreshStatus: func() {
|
||||
e.RunHealthProbes(true)
|
||||
},
|
||||
}
|
||||
|
||||
bundleJobParams := debug.BundleConfig{
|
||||
Anonymize: params.Anonymize,
|
||||
IncludeSystemInfo: true,
|
||||
LogFileCount: uint32(params.LogFileCount),
|
||||
}
|
||||
|
||||
waitFor := time.Duration(params.BundleForTime) * time.Minute
|
||||
|
||||
uploadKey, err := e.jobExecutor.BundleJob(e.ctx, bundleDeps, bundleJobParams, waitFor, e.config.ProfileConfig.ManagementURL.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := &mgmProto.JobResponse_Bundle{
|
||||
Bundle: &mgmProto.BundleResult{
|
||||
UploadKey: uploadKey,
|
||||
},
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
|
||||
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
||||
@@ -1405,6 +1520,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
if e.rpManager != nil {
|
||||
peerConn.SetOnConnected(e.rpManager.OnConnected)
|
||||
peerConn.SetOnDisconnected(e.rpManager.OnDisconnected)
|
||||
peerConn.SetRosenpassInitializedPresharedKeyValidator(e.rpManager.IsPresharedKeyInitialized)
|
||||
}
|
||||
|
||||
return peerConn, nil
|
||||
@@ -1714,7 +1830,7 @@ func (e *Engine) getRosenpassAddr() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
|
||||
// RunHealthProbes executes health checks for Signal, Management, Relay, and WireGuard services
|
||||
// and updates the status recorder with the latest states.
|
||||
func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
||||
e.syncMsgMux.Lock()
|
||||
@@ -1728,23 +1844,8 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
||||
stuns := slices.Clone(e.STUNs)
|
||||
turns := slices.Clone(e.TURNs)
|
||||
|
||||
if e.wgInterface != nil {
|
||||
stats, err := e.wgInterface.GetStats()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get wireguard stats: %v", err)
|
||||
e.syncMsgMux.Unlock()
|
||||
return false
|
||||
}
|
||||
for _, key := range e.peerStore.PeersPubKey() {
|
||||
// wgStats could be zero value, in which case we just reset the stats
|
||||
wgStats, ok := stats[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
|
||||
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
|
||||
}
|
||||
}
|
||||
if err := e.statusRecorder.RefreshWireGuardStats(); err != nil {
|
||||
log.Debugf("failed to refresh WireGuard stats: %v", err)
|
||||
}
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
@@ -1848,8 +1949,8 @@ func (e *Engine) stopDNSServer() {
|
||||
|
||||
// SetSyncResponsePersistence enables or disables sync response persistence
|
||||
func (e *Engine) SetSyncResponsePersistence(enabled bool) {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
e.syncRespMux.Lock()
|
||||
defer e.syncRespMux.Unlock()
|
||||
|
||||
if enabled == e.persistSyncResponse {
|
||||
return
|
||||
@@ -1864,20 +1965,22 @@ func (e *Engine) SetSyncResponsePersistence(enabled bool) {
|
||||
|
||||
// GetLatestSyncResponse returns the stored sync response if persistence is enabled
|
||||
func (e *Engine) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
e.syncRespMux.RLock()
|
||||
enabled := e.persistSyncResponse
|
||||
latest := e.latestSyncResponse
|
||||
e.syncRespMux.RUnlock()
|
||||
|
||||
if !e.persistSyncResponse {
|
||||
if !enabled {
|
||||
return nil, errors.New("sync response persistence is disabled")
|
||||
}
|
||||
|
||||
if e.latestSyncResponse == nil {
|
||||
if latest == nil {
|
||||
//nolint:nilnil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(e.latestSyncResponse))
|
||||
sr, ok := proto.Clone(e.latestSyncResponse).(*mgmProto.SyncResponse)
|
||||
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(latest))
|
||||
sr, ok := proto.Clone(latest).(*mgmProto.SyncResponse)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to clone sync response")
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
@@ -213,6 +214,10 @@ func (m *MockWGIface) LastActivities() map[string]monotime.Time {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockWGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
_ = util.InitLog("debug", util.LogConsole)
|
||||
code := m.Run()
|
||||
@@ -1599,6 +1604,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore)
|
||||
|
||||
@@ -1622,7 +1628,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@@ -1631,7 +1637,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -42,4 +42,5 @@ type wgIfaceBase interface {
|
||||
GetNet() *netstack.Net
|
||||
FullStats() (*configurer.Stats, error)
|
||||
LastActivities() map[string]monotime.Time
|
||||
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
|
||||
}
|
||||
|
||||
@@ -88,8 +88,9 @@ type Conn struct {
|
||||
relayManager *relayClient.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
|
||||
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
||||
onDisconnected func(remotePeer string)
|
||||
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
||||
onDisconnected func(remotePeer string)
|
||||
rosenpassInitializedPresharedKeyValidator func(peerKey string) bool
|
||||
|
||||
statusRelay *worker.AtomicWorkerStatus
|
||||
statusICE *worker.AtomicWorkerStatus
|
||||
@@ -98,7 +99,10 @@ type Conn struct {
|
||||
|
||||
workerICE *WorkerICE
|
||||
workerRelay *WorkerRelay
|
||||
wgWatcherWg sync.WaitGroup
|
||||
|
||||
wgWatcher *WGWatcher
|
||||
wgWatcherWg sync.WaitGroup
|
||||
wgWatcherCancel context.CancelFunc
|
||||
|
||||
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
|
||||
rosenpassRemoteKey []byte
|
||||
@@ -126,6 +130,7 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
|
||||
connLog := log.WithField("peer", config.Key)
|
||||
|
||||
dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
|
||||
var conn = &Conn{
|
||||
Log: connLog,
|
||||
config: config,
|
||||
@@ -137,8 +142,9 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
semaphore: services.Semaphore,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
|
||||
dumpState: dumpState,
|
||||
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
|
||||
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
@@ -162,7 +168,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
|
||||
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
|
||||
|
||||
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
|
||||
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
|
||||
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
@@ -231,7 +237,9 @@ func (conn *Conn) Close(signalToRemote bool) {
|
||||
conn.Log.Infof("close peer connection")
|
||||
conn.ctxCancel()
|
||||
|
||||
conn.workerRelay.DisableWgWatcher()
|
||||
if conn.wgWatcherCancel != nil {
|
||||
conn.wgWatcherCancel()
|
||||
}
|
||||
conn.workerRelay.CloseConn()
|
||||
conn.workerICE.Close()
|
||||
|
||||
@@ -289,6 +297,13 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
|
||||
conn.onDisconnected = handler
|
||||
}
|
||||
|
||||
// SetRosenpassInitializedPresharedKeyValidator sets a function to check if Rosenpass has taken over
|
||||
// PSK management for a peer. When this returns true, presharedKey() returns nil
|
||||
// to prevent UpdatePeer from overwriting the Rosenpass-managed PSK.
|
||||
func (conn *Conn) SetRosenpassInitializedPresharedKeyValidator(handler func(peerKey string) bool) {
|
||||
conn.rosenpassInitializedPresharedKeyValidator = handler
|
||||
}
|
||||
|
||||
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) {
|
||||
conn.dumpState.RemoteOffer()
|
||||
conn.Log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
|
||||
@@ -366,9 +381,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
ep = directEp
|
||||
}
|
||||
|
||||
conn.workerRelay.DisableWgWatcher()
|
||||
// todo consider to run conn.wgWatcherWg.Wait() here
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
conn.wgProxyRelay.Pause()
|
||||
}
|
||||
@@ -390,6 +402,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
conn.wgProxyRelay.RedirectAs(ep)
|
||||
}
|
||||
|
||||
conn.enableWgWatcherIfNeeded()
|
||||
|
||||
conn.currentConnPriority = priority
|
||||
conn.statusICE.SetConnected()
|
||||
conn.updateIceState(iceConnInfo)
|
||||
@@ -423,11 +437,6 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
conn.Log.Errorf("failed to switch to relay conn: %v", err)
|
||||
}
|
||||
|
||||
conn.wgWatcherWg.Add(1)
|
||||
go func() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
}()
|
||||
conn.wgProxyRelay.Work()
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
} else {
|
||||
@@ -444,15 +453,15 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
}
|
||||
conn.statusICE.SetDisconnected()
|
||||
|
||||
conn.disableWgWatcherIfNeeded()
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
ConnStatus: conn.evalStatus(),
|
||||
Relayed: conn.isRelayed(),
|
||||
ConnStatusUpdate: time.Now(),
|
||||
}
|
||||
|
||||
err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState)
|
||||
if err != nil {
|
||||
if err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState); err != nil {
|
||||
conn.Log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -500,11 +509,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
return
|
||||
}
|
||||
|
||||
conn.wgWatcherWg.Add(1)
|
||||
go func() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
}()
|
||||
conn.enableWgWatcherIfNeeded()
|
||||
|
||||
wgConfigWorkaround()
|
||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||
@@ -519,7 +524,11 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
func (conn *Conn) onRelayDisconnected() {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
conn.handleRelayDisconnectedLocked()
|
||||
}
|
||||
|
||||
// handleRelayDisconnectedLocked handles relay disconnection. Caller must hold conn.mu.
|
||||
func (conn *Conn) handleRelayDisconnectedLocked() {
|
||||
if conn.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
@@ -545,6 +554,8 @@ func (conn *Conn) onRelayDisconnected() {
|
||||
}
|
||||
conn.statusRelay.SetDisconnected()
|
||||
|
||||
conn.disableWgWatcherIfNeeded()
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
ConnStatus: conn.evalStatus(),
|
||||
@@ -563,6 +574,28 @@ func (conn *Conn) onGuardEvent() {
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) onWGDisconnected() {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
if conn.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
conn.Log.Warnf("WireGuard handshake timeout detected, closing current connection")
|
||||
|
||||
// Close the active connection based on current priority
|
||||
switch conn.currentConnPriority {
|
||||
case conntype.Relay:
|
||||
conn.workerRelay.CloseConn()
|
||||
conn.handleRelayDisconnectedLocked()
|
||||
case conntype.ICEP2P, conntype.ICETurn:
|
||||
conn.workerICE.Close()
|
||||
default:
|
||||
conn.Log.Debugf("No active connection to close on WG timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
@@ -689,6 +722,25 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
return true
|
||||
}
|
||||
|
||||
func (conn *Conn) enableWgWatcherIfNeeded() {
|
||||
if !conn.wgWatcher.IsEnabled() {
|
||||
wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx)
|
||||
conn.wgWatcherCancel = wgWatcherCancel
|
||||
conn.wgWatcherWg.Add(1)
|
||||
go func() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, conn.onWGDisconnected)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) disableWgWatcherIfNeeded() {
|
||||
if conn.currentConnPriority == conntype.None && conn.wgWatcherCancel != nil {
|
||||
conn.wgWatcherCancel()
|
||||
conn.wgWatcherCancel = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||
conn.Log.Debugf("setup proxied WireGuard connection")
|
||||
udpAddr := &net.UDPAddr{
|
||||
@@ -759,10 +811,24 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
||||
return conn.config.WgConfig.PreSharedKey
|
||||
}
|
||||
|
||||
// If Rosenpass has already set a PSK for this peer, return nil to prevent
|
||||
// UpdatePeer from overwriting the Rosenpass-managed key.
|
||||
if conn.rosenpassInitializedPresharedKeyValidator != nil && conn.rosenpassInitializedPresharedKeyValidator(conn.config.Key) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use NetBird PSK as the seed for Rosenpass. This same PSK is passed to
|
||||
// Rosenpass as PeerConfig.PresharedKey, ensuring the derived post-quantum
|
||||
// key is cryptographically bound to the original secret.
|
||||
if conn.config.WgConfig.PreSharedKey != nil {
|
||||
return conn.config.WgConfig.PreSharedKey
|
||||
}
|
||||
|
||||
// Fallback to deterministic key if no NetBird PSK is configured
|
||||
determKey, err := conn.rosenpassDetermKey()
|
||||
if err != nil {
|
||||
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
|
||||
return conn.config.WgConfig.PreSharedKey
|
||||
return nil
|
||||
}
|
||||
|
||||
return determKey
|
||||
|
||||
@@ -284,3 +284,27 @@ func TestConn_presharedKey(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_presharedKey_RosenpassManaged(t *testing.T) {
|
||||
conn := Conn{
|
||||
config: ConnConfig{
|
||||
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
RosenpassConfig: RosenpassConfig{PubKey: []byte("dummykey")},
|
||||
},
|
||||
}
|
||||
|
||||
// When Rosenpass has already initialized the PSK for this peer,
|
||||
// presharedKey must return nil to avoid UpdatePeer overwriting it.
|
||||
conn.rosenpassInitializedPresharedKeyValidator = func(peerKey string) bool { return true }
|
||||
if k := conn.presharedKey([]byte("remote")); k != nil {
|
||||
t.Fatalf("expected nil presharedKey when Rosenpass manages PSK, got %v", k)
|
||||
}
|
||||
|
||||
// When Rosenpass hasn't taken over yet, presharedKey should provide
|
||||
// a non-nil initial key (deterministic or from NetBird PSK).
|
||||
conn.rosenpassInitializedPresharedKeyValidator = func(peerKey string) bool { return false }
|
||||
if k := conn.presharedKey([]byte("remote")); k == nil {
|
||||
t.Fatalf("expected non-nil presharedKey before Rosenpass manages PSK")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1145,6 +1145,38 @@ func (d *Status) PeersStatus() (*configurer.Stats, error) {
|
||||
return d.wgIface.FullStats()
|
||||
}
|
||||
|
||||
// RefreshWireGuardStats fetches fresh WireGuard statistics from the interface
|
||||
// and updates the cached peer states. This ensures accurate handshake times and
|
||||
// transfer statistics in status reports without running full health probes.
|
||||
func (d *Status) RefreshWireGuardStats() error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
if d.wgIface == nil {
|
||||
return nil // silently skip if interface not set
|
||||
}
|
||||
|
||||
stats, err := d.wgIface.FullStats()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get wireguard stats: %w", err)
|
||||
}
|
||||
|
||||
// Update each peer's WireGuard statistics
|
||||
for _, peerStats := range stats.Peers {
|
||||
peerState, ok := d.peers[peerStats.PublicKey]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
peerState.LastWireguardHandshake = peerStats.LastHandshake
|
||||
peerState.BytesRx = peerStats.RxBytes
|
||||
peerState.BytesTx = peerStats.TxBytes
|
||||
d.peers[peerStats.PublicKey] = peerState
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type EventQueue struct {
|
||||
maxSize int
|
||||
events []*proto.SystemEvent
|
||||
|
||||
@@ -30,10 +30,8 @@ type WGWatcher struct {
|
||||
peerKey string
|
||||
stateDump *stateDump
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
ctxLock sync.Mutex
|
||||
enabledTime time.Time
|
||||
enabled bool
|
||||
muEnabled sync.RWMutex
|
||||
}
|
||||
|
||||
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
|
||||
@@ -46,52 +44,44 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
|
||||
}
|
||||
|
||||
// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing.
|
||||
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
|
||||
w.log.Debugf("enable WireGuard watcher")
|
||||
w.ctxLock.Lock()
|
||||
w.enabledTime = time.Now()
|
||||
|
||||
if w.ctx != nil && w.ctx.Err() == nil {
|
||||
w.log.Errorf("WireGuard watcher already enabled")
|
||||
w.ctxLock.Unlock()
|
||||
// The watcher runs until ctx is cancelled. Caller is responsible for context lifecycle management.
|
||||
func (w *WGWatcher) EnableWgWatcher(ctx context.Context, onDisconnectedFn func()) {
|
||||
w.muEnabled.Lock()
|
||||
if w.enabled {
|
||||
w.muEnabled.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
ctx, ctxCancel := context.WithCancel(parentCtx)
|
||||
w.ctx = ctx
|
||||
w.ctxCancel = ctxCancel
|
||||
w.ctxLock.Unlock()
|
||||
w.log.Debugf("enable WireGuard watcher")
|
||||
enabledTime := time.Now()
|
||||
w.enabled = true
|
||||
w.muEnabled.Unlock()
|
||||
|
||||
initialHandshake, err := w.wgState()
|
||||
if err != nil {
|
||||
w.log.Warnf("failed to read initial wg stats: %v", err)
|
||||
}
|
||||
|
||||
w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
|
||||
w.periodicHandshakeCheck(ctx, onDisconnectedFn, enabledTime, initialHandshake)
|
||||
|
||||
w.muEnabled.Lock()
|
||||
w.enabled = false
|
||||
w.muEnabled.Unlock()
|
||||
}
|
||||
|
||||
// DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit
|
||||
func (w *WGWatcher) DisableWgWatcher() {
|
||||
w.ctxLock.Lock()
|
||||
defer w.ctxLock.Unlock()
|
||||
|
||||
if w.ctxCancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
w.log.Debugf("disable WireGuard watcher")
|
||||
|
||||
w.ctxCancel()
|
||||
w.ctxCancel = nil
|
||||
// IsEnabled returns true if the WireGuard watcher is currently enabled
|
||||
func (w *WGWatcher) IsEnabled() bool {
|
||||
w.muEnabled.RLock()
|
||||
defer w.muEnabled.RUnlock()
|
||||
return w.enabled
|
||||
}
|
||||
|
||||
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
||||
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) {
|
||||
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
|
||||
w.log.Infof("WireGuard watcher started")
|
||||
|
||||
timer := time.NewTimer(wgHandshakeOvertime)
|
||||
defer timer.Stop()
|
||||
defer ctxCancel()
|
||||
|
||||
lastHandshake := initialHandshake
|
||||
|
||||
@@ -104,7 +94,7 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex
|
||||
return
|
||||
}
|
||||
if lastHandshake.IsZero() {
|
||||
elapsed := handshake.Sub(w.enabledTime).Seconds()
|
||||
elapsed := calcElapsed(enabledTime, *handshake)
|
||||
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
|
||||
}
|
||||
|
||||
@@ -134,19 +124,19 @@ func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
|
||||
|
||||
// the current know handshake did not change
|
||||
if handshake.Equal(lastHandshake) {
|
||||
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
|
||||
w.log.Warnf("WireGuard handshake timed out: %v", handshake)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// in case if the machine is suspended, the handshake time will be in the past
|
||||
if handshake.Add(checkPeriod).Before(time.Now()) {
|
||||
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
|
||||
w.log.Warnf("WireGuard handshake timed out: %v", handshake)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// error handling for handshake time in the future
|
||||
if handshake.After(time.Now()) {
|
||||
w.log.Warnf("WireGuard handshake is in the future, closing relay connection: %v", handshake)
|
||||
w.log.Warnf("WireGuard handshake is in the future: %v", handshake)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
@@ -164,3 +154,13 @@ func (w *WGWatcher) wgState() (time.Time, error) {
|
||||
}
|
||||
return wgState.LastHandshake, nil
|
||||
}
|
||||
|
||||
// calcElapsed calculates elapsed time since watcher was enabled.
|
||||
// The watcher started after the wg configuration happens, because of this need to normalise the negative value
|
||||
func calcElapsed(enabledTime, handshake time.Time) float64 {
|
||||
elapsed := handshake.Sub(enabledTime).Seconds()
|
||||
if elapsed < 0 {
|
||||
elapsed = 0
|
||||
}
|
||||
return elapsed
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -48,7 +49,6 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
watcher.DisableWgWatcher()
|
||||
}
|
||||
|
||||
func TestWGWatcher_ReEnable(t *testing.T) {
|
||||
@@ -60,14 +60,21 @@ func TestWGWatcher_ReEnable(t *testing.T) {
|
||||
watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump("peer", mlog, &Status{}))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
watcher.EnableWgWatcher(ctx, func() {})
|
||||
}()
|
||||
cancel()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Re-enable with a new context
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
onDisconnected := make(chan struct{}, 1)
|
||||
|
||||
go watcher.EnableWgWatcher(ctx, func() {})
|
||||
time.Sleep(1 * time.Second)
|
||||
watcher.DisableWgWatcher()
|
||||
|
||||
go watcher.EnableWgWatcher(ctx, func() {
|
||||
onDisconnected <- struct{}{}
|
||||
})
|
||||
@@ -80,5 +87,4 @@ func TestWGWatcher_ReEnable(t *testing.T) {
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
watcher.DisableWgWatcher()
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -286,8 +287,8 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
|
||||
RosenpassAddr: remoteOfferAnswer.RosenpassAddr,
|
||||
LocalIceCandidateType: pair.Local.Type().String(),
|
||||
RemoteIceCandidateType: pair.Remote.Type().String(),
|
||||
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
|
||||
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
|
||||
LocalIceCandidateEndpoint: net.JoinHostPort(pair.Local.Address(), strconv.Itoa(pair.Local.Port())),
|
||||
RemoteIceCandidateEndpoint: net.JoinHostPort(pair.Remote.Address(), strconv.Itoa(pair.Remote.Port())),
|
||||
Relayed: isRelayed(pair),
|
||||
RelayedOnLocal: isRelayCandidate(pair.Local),
|
||||
}
|
||||
@@ -328,13 +329,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
|
||||
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
||||
// wait local endpoint configuration
|
||||
time.Sleep(time.Second)
|
||||
addrString := pair.Remote.Address()
|
||||
parsed, err := netip.ParseAddr(addrString)
|
||||
if (err == nil) && (parsed.Is6()) {
|
||||
addrString = fmt.Sprintf("[%s]", addrString)
|
||||
//IPv6 Literals need to be wrapped in brackets for Resolve*Addr()
|
||||
}
|
||||
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addrString, remoteWgPort))
|
||||
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(pair.Remote.Address(), strconv.Itoa(remoteWgPort)))
|
||||
if err != nil {
|
||||
w.log.Warnf("got an error while resolving the udp address, err: %s", err)
|
||||
return
|
||||
@@ -386,12 +381,44 @@ func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) {
|
||||
sessionID := w.SessionID()
|
||||
stats := agent.GetCandidatePairsStats()
|
||||
localCandidates, _ := agent.GetLocalCandidates()
|
||||
remoteCandidates, _ := agent.GetRemoteCandidates()
|
||||
|
||||
localMap := make(map[string]ice.Candidate)
|
||||
for _, c := range localCandidates {
|
||||
localMap[c.ID()] = c
|
||||
}
|
||||
remoteMap := make(map[string]ice.Candidate)
|
||||
for _, c := range remoteCandidates {
|
||||
remoteMap[c.ID()] = c
|
||||
}
|
||||
|
||||
for _, stat := range stats {
|
||||
if stat.State == ice.CandidatePairStateSucceeded {
|
||||
local, lok := localMap[stat.LocalCandidateID]
|
||||
remote, rok := remoteMap[stat.RemoteCandidateID]
|
||||
if !lok || !rok {
|
||||
continue
|
||||
}
|
||||
w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms",
|
||||
sessionID,
|
||||
local.NetworkType(), local.Type(), local.Address(),
|
||||
remote.NetworkType(), remote.Type(), remote.Address(),
|
||||
stat.CurrentRoundTripTime*1000)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dialerCancel context.CancelFunc) func(ice.ConnectionState) {
|
||||
return func(state ice.ConnectionState) {
|
||||
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
|
||||
switch state {
|
||||
case ice.ConnectionStateConnected:
|
||||
w.lastKnownState = ice.ConnectionStateConnected
|
||||
w.logSuccessfulPaths(agent)
|
||||
return
|
||||
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected, ice.ConnectionStateClosed:
|
||||
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
||||
|
||||
@@ -30,11 +30,9 @@ type WorkerRelay struct {
|
||||
relayLock sync.Mutex
|
||||
|
||||
relaySupportedOnRemotePeer atomic.Bool
|
||||
|
||||
wgWatcher *WGWatcher
|
||||
}
|
||||
|
||||
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager, stateDump *stateDump) *WorkerRelay {
|
||||
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager) *WorkerRelay {
|
||||
r := &WorkerRelay{
|
||||
peerCtx: ctx,
|
||||
log: log,
|
||||
@@ -42,7 +40,6 @@ func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnC
|
||||
config: config,
|
||||
conn: conn,
|
||||
relayManager: relayManager,
|
||||
wgWatcher: NewWGWatcher(log, config.WgConfig.WgInterface, config.Key, stateDump),
|
||||
}
|
||||
return r
|
||||
}
|
||||
@@ -93,14 +90,6 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
})
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
|
||||
w.wgWatcher.EnableWgWatcher(ctx, w.onWGDisconnected)
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) DisableWgWatcher() {
|
||||
w.wgWatcher.DisableWgWatcher()
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
|
||||
return w.relayManager.RelayInstanceAddress()
|
||||
}
|
||||
@@ -125,14 +114,6 @@ func (w *WorkerRelay) CloseConn() {
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) onWGDisconnected() {
|
||||
w.relayLock.Lock()
|
||||
_ = w.relayedConn.Close()
|
||||
w.relayLock.Unlock()
|
||||
|
||||
w.conn.onRelayDisconnected()
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
|
||||
if !w.relayManager.HasRelayAddress() {
|
||||
return false
|
||||
@@ -148,6 +129,5 @@ func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress st
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) onRelayClientDisconnected() {
|
||||
w.wgWatcher.DisableWgWatcher()
|
||||
go w.conn.onRelayDisconnected()
|
||||
}
|
||||
|
||||
@@ -17,6 +17,11 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultLog = slog.LevelInfo
|
||||
defaultLogLevelVar = "NB_ROSENPASS_LOG_LEVEL"
|
||||
)
|
||||
|
||||
func hashRosenpassKey(key []byte) string {
|
||||
hasher := sha256.New()
|
||||
hasher.Write(key)
|
||||
@@ -34,6 +39,7 @@ type Manager struct {
|
||||
server *rp.Server
|
||||
lock sync.Mutex
|
||||
port int
|
||||
wgIface PresharedKeySetter
|
||||
}
|
||||
|
||||
// NewManager creates a new Rosenpass manager
|
||||
@@ -44,7 +50,7 @@ func NewManager(preSharedKey *wgtypes.Key, wgIfaceName string) (*Manager, error)
|
||||
}
|
||||
|
||||
rpKeyHash := hashRosenpassKey(public)
|
||||
log.Debugf("generated new rosenpass key pair with public key %s", rpKeyHash)
|
||||
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
|
||||
return &Manager{ifaceName: wgIfaceName, rpKeyHash: rpKeyHash, spk: public, ssk: secret, preSharedKey: (*[32]byte)(preSharedKey), rpPeerIDs: make(map[string]*rp.PeerID), lock: sync.Mutex{}}, nil
|
||||
}
|
||||
|
||||
@@ -100,7 +106,7 @@ func (m *Manager) removePeer(wireGuardPubKey string) error {
|
||||
|
||||
func (m *Manager) generateConfig() (rp.Config, error) {
|
||||
opts := &slog.HandlerOptions{
|
||||
Level: slog.LevelDebug,
|
||||
Level: getLogLevel(),
|
||||
}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, opts))
|
||||
cfg := rp.Config{Logger: logger}
|
||||
@@ -109,7 +115,13 @@ func (m *Manager) generateConfig() (rp.Config, error) {
|
||||
cfg.SecretKey = m.ssk
|
||||
|
||||
cfg.Peers = []rp.PeerConfig{}
|
||||
m.rpWgHandler, _ = NewNetbirdHandler(m.preSharedKey, m.ifaceName)
|
||||
|
||||
m.lock.Lock()
|
||||
m.rpWgHandler = NewNetbirdHandler()
|
||||
if m.wgIface != nil {
|
||||
m.rpWgHandler.SetInterface(m.wgIface)
|
||||
}
|
||||
m.lock.Unlock()
|
||||
|
||||
cfg.Handlers = []rp.Handler{m.rpWgHandler}
|
||||
|
||||
@@ -126,6 +138,26 @@ func (m *Manager) generateConfig() (rp.Config, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func getLogLevel() slog.Level {
|
||||
level, ok := os.LookupEnv(defaultLogLevelVar)
|
||||
if !ok {
|
||||
return defaultLog
|
||||
}
|
||||
switch strings.ToLower(level) {
|
||||
case "debug":
|
||||
return slog.LevelDebug
|
||||
case "info":
|
||||
return slog.LevelInfo
|
||||
case "warn":
|
||||
return slog.LevelWarn
|
||||
case "error":
|
||||
return slog.LevelError
|
||||
default:
|
||||
log.Warnf("unknown log level: %s. Using default %s", level, defaultLog.String())
|
||||
return defaultLog
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) OnDisconnected(peerKey string) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
@@ -172,6 +204,20 @@ func (m *Manager) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetInterface sets the WireGuard interface for the rosenpass handler.
|
||||
// This can be called before or after Run() - the interface will be stored
|
||||
// and passed to the handler when it's created or updated immediately if
|
||||
// already running.
|
||||
func (m *Manager) SetInterface(iface PresharedKeySetter) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
m.wgIface = iface
|
||||
if m.rpWgHandler != nil {
|
||||
m.rpWgHandler.SetInterface(iface)
|
||||
}
|
||||
}
|
||||
|
||||
// OnConnected is a handler function that is triggered when a connection to a remote peer establishes
|
||||
func (m *Manager) OnConnected(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) {
|
||||
m.lock.Lock()
|
||||
@@ -192,6 +238,20 @@ func (m *Manager) OnConnected(remoteWireGuardKey string, remoteRosenpassPubKey [
|
||||
}
|
||||
}
|
||||
|
||||
// IsPresharedKeyInitialized returns true if Rosenpass has completed a handshake
|
||||
// and set a PSK for the given WireGuard peer.
|
||||
func (m *Manager) IsPresharedKeyInitialized(wireGuardPubKey string) bool {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
peerID, ok := m.rpPeerIDs[wireGuardPubKey]
|
||||
if !ok || peerID == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return m.rpWgHandler.IsPeerInitialized(*peerID)
|
||||
}
|
||||
|
||||
func findRandomAvailableUDPPort() (int, error) {
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
|
||||
@@ -1,46 +1,50 @@
|
||||
package rosenpass
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
rp "cunicu.li/go-rosenpass"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// PresharedKeySetter is the interface for setting preshared keys on WireGuard peers.
|
||||
// This minimal interface allows rosenpass to update PSKs without depending on the full WGIface.
|
||||
type PresharedKeySetter interface {
|
||||
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
|
||||
}
|
||||
|
||||
type wireGuardPeer struct {
|
||||
Interface string
|
||||
PublicKey rp.Key
|
||||
}
|
||||
|
||||
type NetbirdHandler struct {
|
||||
ifaceName string
|
||||
client *wgctrl.Client
|
||||
peers map[rp.PeerID]wireGuardPeer
|
||||
presharedKey [32]byte
|
||||
mu sync.Mutex
|
||||
iface PresharedKeySetter
|
||||
peers map[rp.PeerID]wireGuardPeer
|
||||
initializedPeers map[rp.PeerID]bool
|
||||
}
|
||||
|
||||
func NewNetbirdHandler(preSharedKey *[32]byte, wgIfaceName string) (hdlr *NetbirdHandler, err error) {
|
||||
hdlr = &NetbirdHandler{
|
||||
ifaceName: wgIfaceName,
|
||||
peers: map[rp.PeerID]wireGuardPeer{},
|
||||
func NewNetbirdHandler() *NetbirdHandler {
|
||||
return &NetbirdHandler{
|
||||
peers: map[rp.PeerID]wireGuardPeer{},
|
||||
initializedPeers: map[rp.PeerID]bool{},
|
||||
}
|
||||
}
|
||||
|
||||
if preSharedKey != nil {
|
||||
hdlr.presharedKey = *preSharedKey
|
||||
}
|
||||
|
||||
if hdlr.client, err = wgctrl.New(); err != nil {
|
||||
return nil, fmt.Errorf("failed to creat WireGuard client: %w", err)
|
||||
}
|
||||
|
||||
return hdlr, nil
|
||||
// SetInterface sets the WireGuard interface for the handler.
|
||||
// This must be called after the WireGuard interface is created.
|
||||
func (h *NetbirdHandler) SetInterface(iface PresharedKeySetter) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.iface = iface
|
||||
}
|
||||
|
||||
func (h *NetbirdHandler) AddPeer(pid rp.PeerID, intf string, pk rp.Key) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.peers[pid] = wireGuardPeer{
|
||||
Interface: intf,
|
||||
PublicKey: pk,
|
||||
@@ -48,79 +52,61 @@ func (h *NetbirdHandler) AddPeer(pid rp.PeerID, intf string, pk rp.Key) {
|
||||
}
|
||||
|
||||
func (h *NetbirdHandler) RemovePeer(pid rp.PeerID) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
delete(h.peers, pid)
|
||||
delete(h.initializedPeers, pid)
|
||||
}
|
||||
|
||||
// IsPeerInitialized returns true if Rosenpass has completed a handshake
|
||||
// and set a PSK for this peer.
|
||||
func (h *NetbirdHandler) IsPeerInitialized(pid rp.PeerID) bool {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
return h.initializedPeers[pid]
|
||||
}
|
||||
|
||||
func (h *NetbirdHandler) HandshakeCompleted(pid rp.PeerID, key rp.Key) {
|
||||
log.Debug("Handshake complete")
|
||||
h.outputKey(rp.KeyOutputReasonStale, pid, key)
|
||||
}
|
||||
|
||||
func (h *NetbirdHandler) HandshakeExpired(pid rp.PeerID) {
|
||||
key, _ := rp.GeneratePresharedKey()
|
||||
log.Debug("Handshake expired")
|
||||
h.outputKey(rp.KeyOutputReasonStale, pid, key)
|
||||
}
|
||||
|
||||
func (h *NetbirdHandler) outputKey(_ rp.KeyOutputReason, pid rp.PeerID, psk rp.Key) {
|
||||
h.mu.Lock()
|
||||
iface := h.iface
|
||||
wg, ok := h.peers[pid]
|
||||
isInitialized := h.initializedPeers[pid]
|
||||
h.mu.Unlock()
|
||||
|
||||
if iface == nil {
|
||||
log.Warn("rosenpass: interface not set, cannot update preshared key")
|
||||
return
|
||||
}
|
||||
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
device, err := h.client.Device(h.ifaceName)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get WireGuard device: %v", err)
|
||||
peerKey := wgtypes.Key(wg.PublicKey).String()
|
||||
pskKey := wgtypes.Key(psk)
|
||||
|
||||
// Use updateOnly=true for later rotations (peer already has Rosenpass PSK)
|
||||
// Use updateOnly=false for first rotation (peer has original/empty PSK)
|
||||
if err := iface.SetPresharedKey(peerKey, pskKey, isInitialized); err != nil {
|
||||
log.Errorf("Failed to apply rosenpass key: %v", err)
|
||||
return
|
||||
}
|
||||
config := []wgtypes.PeerConfig{
|
||||
{
|
||||
UpdateOnly: true,
|
||||
PublicKey: wgtypes.Key(wg.PublicKey),
|
||||
PresharedKey: (*wgtypes.Key)(&psk),
|
||||
},
|
||||
}
|
||||
for _, peer := range device.Peers {
|
||||
if peer.PublicKey == wgtypes.Key(wg.PublicKey) {
|
||||
if publicKeyEmpty(peer.PresharedKey) || peer.PresharedKey == h.presharedKey {
|
||||
log.Debugf("Restart wireguard connection to peer %s", peer.PublicKey)
|
||||
config = []wgtypes.PeerConfig{
|
||||
{
|
||||
PublicKey: wgtypes.Key(wg.PublicKey),
|
||||
PresharedKey: (*wgtypes.Key)(&psk),
|
||||
Endpoint: peer.Endpoint,
|
||||
AllowedIPs: peer.AllowedIPs,
|
||||
},
|
||||
}
|
||||
err = h.client.ConfigureDevice(wg.Interface, wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{
|
||||
{
|
||||
Remove: true,
|
||||
PublicKey: wgtypes.Key(wg.PublicKey),
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
slog.Debug("Failed to remove peer")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Mark peer as isInitialized after the successful first rotation
|
||||
if !isInitialized {
|
||||
h.mu.Lock()
|
||||
if _, exists := h.peers[pid]; exists {
|
||||
h.initializedPeers[pid] = true
|
||||
}
|
||||
}
|
||||
|
||||
if err = h.client.ConfigureDevice(wg.Interface, wgtypes.Config{
|
||||
Peers: config,
|
||||
}); err != nil {
|
||||
log.Errorf("Failed to apply rosenpass key: %v", err)
|
||||
h.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func publicKeyEmpty(key wgtypes.Key) bool {
|
||||
for _, b := range key {
|
||||
if b != 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
76
client/jobexec/executor.go
Normal file
76
client/jobexec/executor.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package jobexec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
MaxBundleWaitTime = 60 * time.Minute // maximum wait time for bundle generation (1 hour)
|
||||
)
|
||||
|
||||
var (
|
||||
ErrJobNotImplemented = errors.New("job not implemented")
|
||||
)
|
||||
|
||||
type Executor struct {
|
||||
}
|
||||
|
||||
func NewExecutor() *Executor {
|
||||
return &Executor{}
|
||||
}
|
||||
|
||||
func (e *Executor) BundleJob(ctx context.Context, debugBundleDependencies debug.GeneratorDependencies, params debug.BundleConfig, waitForDuration time.Duration, mgmURL string) (string, error) {
|
||||
if waitForDuration > MaxBundleWaitTime {
|
||||
log.Warnf("bundle wait time %v exceeds maximum %v, capping to maximum", waitForDuration, MaxBundleWaitTime)
|
||||
waitForDuration = MaxBundleWaitTime
|
||||
}
|
||||
|
||||
if waitForDuration > 0 {
|
||||
if err := waitFor(ctx, waitForDuration); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("execute debug bundle generation")
|
||||
|
||||
bundleGenerator := debug.NewBundleGenerator(debugBundleDependencies, params)
|
||||
|
||||
path, err := bundleGenerator.Generate()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate debug bundle: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Remove(path); err != nil {
|
||||
log.Errorf("failed to remove debug bundle file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
key, err := debug.UploadDebugBundle(ctx, types.DefaultBundleURL, mgmURL, path)
|
||||
if err != nil {
|
||||
log.Errorf("failed to upload debug bundle: %v", err)
|
||||
return "", fmt.Errorf("upload debug bundle: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("debug bundle has been generated successfully")
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func waitFor(ctx context.Context, duration time.Duration) error {
|
||||
log.Infof("wait for %v minutes before executing debug bundle", duration.Minutes())
|
||||
select {
|
||||
case <-time.After(duration):
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
log.Infof("wait cancelled: %v", ctx.Err())
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.36.6
|
||||
// protoc v3.21.12
|
||||
// protoc v6.32.1
|
||||
// source: daemon.proto
|
||||
|
||||
package proto
|
||||
@@ -2757,7 +2757,6 @@ func (x *ForwardingRulesResponse) GetRules() []*ForwardingRule {
|
||||
type DebugBundleRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Anonymize bool `protobuf:"varint,1,opt,name=anonymize,proto3" json:"anonymize,omitempty"`
|
||||
Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"`
|
||||
SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"`
|
||||
UploadURL string `protobuf:"bytes,4,opt,name=uploadURL,proto3" json:"uploadURL,omitempty"`
|
||||
LogFileCount uint32 `protobuf:"varint,5,opt,name=logFileCount,proto3" json:"logFileCount,omitempty"`
|
||||
@@ -2802,13 +2801,6 @@ func (x *DebugBundleRequest) GetAnonymize() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *DebugBundleRequest) GetStatus() string {
|
||||
if x != nil {
|
||||
return x.Status
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *DebugBundleRequest) GetSystemInfo() bool {
|
||||
if x != nil {
|
||||
return x.SystemInfo
|
||||
@@ -5372,6 +5364,154 @@ func (x *WaitJWTTokenResponse) GetExpiresIn() int64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// StartCPUProfileRequest for starting CPU profiling
|
||||
type StartCPUProfileRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *StartCPUProfileRequest) Reset() {
|
||||
*x = StartCPUProfileRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[79]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *StartCPUProfileRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*StartCPUProfileRequest) ProtoMessage() {}
|
||||
|
||||
func (x *StartCPUProfileRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[79]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use StartCPUProfileRequest.ProtoReflect.Descriptor instead.
|
||||
func (*StartCPUProfileRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{79}
|
||||
}
|
||||
|
||||
// StartCPUProfileResponse confirms CPU profiling has started
|
||||
type StartCPUProfileResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *StartCPUProfileResponse) Reset() {
|
||||
*x = StartCPUProfileResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[80]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *StartCPUProfileResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*StartCPUProfileResponse) ProtoMessage() {}
|
||||
|
||||
func (x *StartCPUProfileResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[80]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use StartCPUProfileResponse.ProtoReflect.Descriptor instead.
|
||||
func (*StartCPUProfileResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{80}
|
||||
}
|
||||
|
||||
// StopCPUProfileRequest for stopping CPU profiling
|
||||
type StopCPUProfileRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *StopCPUProfileRequest) Reset() {
|
||||
*x = StopCPUProfileRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[81]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *StopCPUProfileRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*StopCPUProfileRequest) ProtoMessage() {}
|
||||
|
||||
func (x *StopCPUProfileRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[81]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use StopCPUProfileRequest.ProtoReflect.Descriptor instead.
|
||||
func (*StopCPUProfileRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{81}
|
||||
}
|
||||
|
||||
// StopCPUProfileResponse confirms CPU profiling has stopped
|
||||
type StopCPUProfileResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *StopCPUProfileResponse) Reset() {
|
||||
*x = StopCPUProfileResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[82]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *StopCPUProfileResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*StopCPUProfileResponse) ProtoMessage() {}
|
||||
|
||||
func (x *StopCPUProfileResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[82]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use StopCPUProfileResponse.ProtoReflect.Descriptor instead.
|
||||
func (*StopCPUProfileResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{82}
|
||||
}
|
||||
|
||||
type InstallerResultRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
@@ -5380,7 +5520,7 @@ type InstallerResultRequest struct {
|
||||
|
||||
func (x *InstallerResultRequest) Reset() {
|
||||
*x = InstallerResultRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[79]
|
||||
mi := &file_daemon_proto_msgTypes[83]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5392,7 +5532,7 @@ func (x *InstallerResultRequest) String() string {
|
||||
func (*InstallerResultRequest) ProtoMessage() {}
|
||||
|
||||
func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[79]
|
||||
mi := &file_daemon_proto_msgTypes[83]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5405,7 +5545,7 @@ func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead.
|
||||
func (*InstallerResultRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{79}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{83}
|
||||
}
|
||||
|
||||
type InstallerResultResponse struct {
|
||||
@@ -5418,7 +5558,7 @@ type InstallerResultResponse struct {
|
||||
|
||||
func (x *InstallerResultResponse) Reset() {
|
||||
*x = InstallerResultResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[80]
|
||||
mi := &file_daemon_proto_msgTypes[84]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5430,7 +5570,7 @@ func (x *InstallerResultResponse) String() string {
|
||||
func (*InstallerResultResponse) ProtoMessage() {}
|
||||
|
||||
func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[80]
|
||||
mi := &file_daemon_proto_msgTypes[84]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5443,7 +5583,7 @@ func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead.
|
||||
func (*InstallerResultResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{80}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{84}
|
||||
}
|
||||
|
||||
func (x *InstallerResultResponse) GetSuccess() bool {
|
||||
@@ -5470,7 +5610,7 @@ type PortInfo_Range struct {
|
||||
|
||||
func (x *PortInfo_Range) Reset() {
|
||||
*x = PortInfo_Range{}
|
||||
mi := &file_daemon_proto_msgTypes[82]
|
||||
mi := &file_daemon_proto_msgTypes[86]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5482,7 +5622,7 @@ func (x *PortInfo_Range) String() string {
|
||||
func (*PortInfo_Range) ProtoMessage() {}
|
||||
|
||||
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[82]
|
||||
mi := &file_daemon_proto_msgTypes[86]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5773,10 +5913,9 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\x12translatedHostname\x18\x04 \x01(\tR\x12translatedHostname\x128\n" +
|
||||
"\x0etranslatedPort\x18\x05 \x01(\v2\x10.daemon.PortInfoR\x0etranslatedPort\"G\n" +
|
||||
"\x17ForwardingRulesResponse\x12,\n" +
|
||||
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\xac\x01\n" +
|
||||
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\x94\x01\n" +
|
||||
"\x12DebugBundleRequest\x12\x1c\n" +
|
||||
"\tanonymize\x18\x01 \x01(\bR\tanonymize\x12\x16\n" +
|
||||
"\x06status\x18\x02 \x01(\tR\x06status\x12\x1e\n" +
|
||||
"\tanonymize\x18\x01 \x01(\bR\tanonymize\x12\x1e\n" +
|
||||
"\n" +
|
||||
"systemInfo\x18\x03 \x01(\bR\n" +
|
||||
"systemInfo\x12\x1c\n" +
|
||||
@@ -6003,6 +6142,10 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" +
|
||||
"\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" +
|
||||
"\texpiresIn\x18\x03 \x01(\x03R\texpiresIn\"\x18\n" +
|
||||
"\x16StartCPUProfileRequest\"\x19\n" +
|
||||
"\x17StartCPUProfileResponse\"\x17\n" +
|
||||
"\x15StopCPUProfileRequest\"\x18\n" +
|
||||
"\x16StopCPUProfileResponse\"\x18\n" +
|
||||
"\x16InstallerResultRequest\"O\n" +
|
||||
"\x17InstallerResultResponse\x12\x18\n" +
|
||||
"\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +
|
||||
@@ -6015,7 +6158,7 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\x04WARN\x10\x04\x12\b\n" +
|
||||
"\x04INFO\x10\x05\x12\t\n" +
|
||||
"\x05DEBUG\x10\x06\x12\t\n" +
|
||||
"\x05TRACE\x10\a2\xb4\x13\n" +
|
||||
"\x05TRACE\x10\a2\xdd\x14\n" +
|
||||
"\rDaemonService\x126\n" +
|
||||
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
|
||||
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
|
||||
@@ -6050,7 +6193,9 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" +
|
||||
"\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" +
|
||||
"\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" +
|
||||
"\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12N\n" +
|
||||
"\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12T\n" +
|
||||
"\x0fStartCPUProfile\x12\x1e.daemon.StartCPUProfileRequest\x1a\x1f.daemon.StartCPUProfileResponse\"\x00\x12Q\n" +
|
||||
"\x0eStopCPUProfile\x12\x1d.daemon.StopCPUProfileRequest\x1a\x1e.daemon.StopCPUProfileResponse\"\x00\x12N\n" +
|
||||
"\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" +
|
||||
"\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3"
|
||||
|
||||
@@ -6067,7 +6212,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
|
||||
}
|
||||
|
||||
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4)
|
||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 84)
|
||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 88)
|
||||
var file_daemon_proto_goTypes = []any{
|
||||
(LogLevel)(0), // 0: daemon.LogLevel
|
||||
(OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType
|
||||
@@ -6152,21 +6297,25 @@ var file_daemon_proto_goTypes = []any{
|
||||
(*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse
|
||||
(*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest
|
||||
(*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse
|
||||
(*InstallerResultRequest)(nil), // 83: daemon.InstallerResultRequest
|
||||
(*InstallerResultResponse)(nil), // 84: daemon.InstallerResultResponse
|
||||
nil, // 85: daemon.Network.ResolvedIPsEntry
|
||||
(*PortInfo_Range)(nil), // 86: daemon.PortInfo.Range
|
||||
nil, // 87: daemon.SystemEvent.MetadataEntry
|
||||
(*durationpb.Duration)(nil), // 88: google.protobuf.Duration
|
||||
(*timestamppb.Timestamp)(nil), // 89: google.protobuf.Timestamp
|
||||
(*StartCPUProfileRequest)(nil), // 83: daemon.StartCPUProfileRequest
|
||||
(*StartCPUProfileResponse)(nil), // 84: daemon.StartCPUProfileResponse
|
||||
(*StopCPUProfileRequest)(nil), // 85: daemon.StopCPUProfileRequest
|
||||
(*StopCPUProfileResponse)(nil), // 86: daemon.StopCPUProfileResponse
|
||||
(*InstallerResultRequest)(nil), // 87: daemon.InstallerResultRequest
|
||||
(*InstallerResultResponse)(nil), // 88: daemon.InstallerResultResponse
|
||||
nil, // 89: daemon.Network.ResolvedIPsEntry
|
||||
(*PortInfo_Range)(nil), // 90: daemon.PortInfo.Range
|
||||
nil, // 91: daemon.SystemEvent.MetadataEntry
|
||||
(*durationpb.Duration)(nil), // 92: google.protobuf.Duration
|
||||
(*timestamppb.Timestamp)(nil), // 93: google.protobuf.Timestamp
|
||||
}
|
||||
var file_daemon_proto_depIdxs = []int32{
|
||||
1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
|
||||
88, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
92, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
||||
89, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||
89, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||
88, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||
93, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||
93, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||
92, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||
25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
|
||||
22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
||||
21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
||||
@@ -6177,8 +6326,8 @@ var file_daemon_proto_depIdxs = []int32{
|
||||
57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
|
||||
26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
|
||||
33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
|
||||
85, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
||||
86, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
|
||||
89, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
||||
90, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
|
||||
34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
|
||||
34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
|
||||
35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
|
||||
@@ -6189,10 +6338,10 @@ var file_daemon_proto_depIdxs = []int32{
|
||||
54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
|
||||
2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
|
||||
3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
|
||||
89, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||
87, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
|
||||
93, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||
91, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
|
||||
57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
|
||||
88, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
92, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
|
||||
32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
|
||||
7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
||||
@@ -6226,43 +6375,47 @@ var file_daemon_proto_depIdxs = []int32{
|
||||
77, // 62: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
|
||||
79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
|
||||
81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
|
||||
5, // 65: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
|
||||
83, // 66: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
|
||||
8, // 67: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||
10, // 68: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||
12, // 69: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||
14, // 70: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
||||
16, // 71: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
||||
18, // 72: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
||||
29, // 73: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
|
||||
31, // 74: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||
31, // 75: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||
36, // 76: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
|
||||
38, // 77: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
||||
40, // 78: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
||||
42, // 79: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
||||
45, // 80: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
|
||||
47, // 81: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
|
||||
49, // 82: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
||||
51, // 83: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
|
||||
55, // 84: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
|
||||
57, // 85: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
|
||||
59, // 86: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
|
||||
61, // 87: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
|
||||
63, // 88: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
|
||||
65, // 89: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
|
||||
67, // 90: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
|
||||
69, // 91: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
|
||||
72, // 92: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
|
||||
74, // 93: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
|
||||
76, // 94: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
|
||||
78, // 95: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
|
||||
80, // 96: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
|
||||
82, // 97: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
|
||||
6, // 98: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
|
||||
84, // 99: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
|
||||
67, // [67:100] is the sub-list for method output_type
|
||||
34, // [34:67] is the sub-list for method input_type
|
||||
83, // 65: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
|
||||
85, // 66: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
|
||||
5, // 67: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
|
||||
87, // 68: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
|
||||
8, // 69: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||
10, // 70: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||
12, // 71: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||
14, // 72: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
||||
16, // 73: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
||||
18, // 74: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
||||
29, // 75: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
|
||||
31, // 76: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||
31, // 77: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||
36, // 78: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
|
||||
38, // 79: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
||||
40, // 80: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
||||
42, // 81: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
||||
45, // 82: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
|
||||
47, // 83: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
|
||||
49, // 84: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
||||
51, // 85: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
|
||||
55, // 86: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
|
||||
57, // 87: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
|
||||
59, // 88: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
|
||||
61, // 89: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
|
||||
63, // 90: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
|
||||
65, // 91: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
|
||||
67, // 92: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
|
||||
69, // 93: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
|
||||
72, // 94: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
|
||||
74, // 95: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
|
||||
76, // 96: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
|
||||
78, // 97: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
|
||||
80, // 98: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
|
||||
82, // 99: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
|
||||
84, // 100: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
|
||||
86, // 101: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
|
||||
6, // 102: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
|
||||
88, // 103: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
|
||||
69, // [69:104] is the sub-list for method output_type
|
||||
34, // [34:69] is the sub-list for method input_type
|
||||
34, // [34:34] is the sub-list for extension type_name
|
||||
34, // [34:34] is the sub-list for extension extendee
|
||||
0, // [0:34] is the sub-list for field type_name
|
||||
@@ -6292,7 +6445,7 @@ func file_daemon_proto_init() {
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
|
||||
NumEnums: 4,
|
||||
NumMessages: 84,
|
||||
NumMessages: 88,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
|
||||
@@ -94,6 +94,12 @@ service DaemonService {
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
|
||||
|
||||
// StartCPUProfile starts CPU profiling in the daemon
|
||||
rpc StartCPUProfile(StartCPUProfileRequest) returns (StartCPUProfileResponse) {}
|
||||
|
||||
// StopCPUProfile stops CPU profiling in the daemon
|
||||
rpc StopCPUProfile(StopCPUProfileRequest) returns (StopCPUProfileResponse) {}
|
||||
|
||||
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
|
||||
|
||||
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
|
||||
@@ -455,7 +461,6 @@ message ForwardingRulesResponse {
|
||||
// DebugBundler
|
||||
message DebugBundleRequest {
|
||||
bool anonymize = 1;
|
||||
string status = 2;
|
||||
bool systemInfo = 3;
|
||||
string uploadURL = 4;
|
||||
uint32 logFileCount = 5;
|
||||
@@ -777,6 +782,18 @@ message WaitJWTTokenResponse {
|
||||
int64 expiresIn = 3;
|
||||
}
|
||||
|
||||
// StartCPUProfileRequest for starting CPU profiling
|
||||
message StartCPUProfileRequest {}
|
||||
|
||||
// StartCPUProfileResponse confirms CPU profiling has started
|
||||
message StartCPUProfileResponse {}
|
||||
|
||||
// StopCPUProfileRequest for stopping CPU profiling
|
||||
message StopCPUProfileRequest {}
|
||||
|
||||
// StopCPUProfileResponse confirms CPU profiling has stopped
|
||||
message StopCPUProfileResponse {}
|
||||
|
||||
message InstallerResultRequest {
|
||||
}
|
||||
|
||||
|
||||
@@ -70,6 +70,10 @@ type DaemonServiceClient interface {
|
||||
RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error)
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
|
||||
// StartCPUProfile starts CPU profiling in the daemon
|
||||
StartCPUProfile(ctx context.Context, in *StartCPUProfileRequest, opts ...grpc.CallOption) (*StartCPUProfileResponse, error)
|
||||
// StopCPUProfile stops CPU profiling in the daemon
|
||||
StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error)
|
||||
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
|
||||
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
|
||||
}
|
||||
@@ -384,6 +388,24 @@ func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTToken
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) StartCPUProfile(ctx context.Context, in *StartCPUProfileRequest, opts ...grpc.CallOption) (*StartCPUProfileResponse, error) {
|
||||
out := new(StartCPUProfileResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/StartCPUProfile", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error) {
|
||||
out := new(StopCPUProfileResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/StopCPUProfile", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) {
|
||||
out := new(OSLifecycleResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/NotifyOSLifecycle", in, out, opts...)
|
||||
@@ -458,6 +480,10 @@ type DaemonServiceServer interface {
|
||||
RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error)
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
|
||||
// StartCPUProfile starts CPU profiling in the daemon
|
||||
StartCPUProfile(context.Context, *StartCPUProfileRequest) (*StartCPUProfileResponse, error)
|
||||
// StopCPUProfile stops CPU profiling in the daemon
|
||||
StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error)
|
||||
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
|
||||
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
|
||||
mustEmbedUnimplementedDaemonServiceServer()
|
||||
@@ -560,6 +586,12 @@ func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *Request
|
||||
func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) StartCPUProfile(context.Context, *StartCPUProfileRequest) (*StartCPUProfileResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method StartCPUProfile not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method StopCPUProfile not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented")
|
||||
}
|
||||
@@ -1140,6 +1172,42 @@ func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, d
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_StartCPUProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(StartCPUProfileRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).StartCPUProfile(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/StartCPUProfile",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).StartCPUProfile(ctx, req.(*StartCPUProfileRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_StopCPUProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(StopCPUProfileRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).StopCPUProfile(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/StopCPUProfile",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).StopCPUProfile(ctx, req.(*StopCPUProfileRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(OSLifecycleRequest)
|
||||
if err := dec(in); err != nil {
|
||||
@@ -1303,6 +1371,14 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "WaitJWTToken",
|
||||
Handler: _DaemonService_WaitJWTToken_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "StartCPUProfile",
|
||||
Handler: _DaemonService_StartCPUProfile_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "StopCPUProfile",
|
||||
Handler: _DaemonService_StopCPUProfile_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "NotifyOSLifecycle",
|
||||
Handler: _DaemonService_NotifyOSLifecycle_Handler,
|
||||
|
||||
@@ -3,25 +3,19 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime/pprof"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
const maxBundleUploadSize = 50 * 1024 * 1024
|
||||
|
||||
// DebugBundle creates a debug bundle and returns the location.
|
||||
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
|
||||
s.mutex.Lock()
|
||||
@@ -32,16 +26,37 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
log.Warnf("failed to get latest sync response: %v", err)
|
||||
}
|
||||
|
||||
var cpuProfileData []byte
|
||||
if s.cpuProfileBuf != nil && !s.cpuProfiling {
|
||||
cpuProfileData = s.cpuProfileBuf.Bytes()
|
||||
defer func() {
|
||||
s.cpuProfileBuf = nil
|
||||
}()
|
||||
}
|
||||
|
||||
// Prepare refresh callback for health probes
|
||||
var refreshStatus func()
|
||||
if s.connectClient != nil {
|
||||
engine := s.connectClient.Engine()
|
||||
if engine != nil {
|
||||
refreshStatus = func() {
|
||||
log.Debug("refreshing system health status for debug bundle")
|
||||
engine.RunHealthProbes(true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bundleGenerator := debug.NewBundleGenerator(
|
||||
debug.GeneratorDependencies{
|
||||
InternalConfig: s.config,
|
||||
StatusRecorder: s.statusRecorder,
|
||||
SyncResponse: syncResponse,
|
||||
LogFile: s.logFile,
|
||||
LogPath: s.logFile,
|
||||
CPUProfile: cpuProfileData,
|
||||
RefreshStatus: refreshStatus,
|
||||
},
|
||||
debug.BundleConfig{
|
||||
Anonymize: req.GetAnonymize(),
|
||||
ClientStatus: req.GetStatus(),
|
||||
IncludeSystemInfo: req.GetSystemInfo(),
|
||||
LogFileCount: req.GetLogFileCount(),
|
||||
},
|
||||
@@ -55,7 +70,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
if req.GetUploadURL() == "" {
|
||||
return &proto.DebugBundleResponse{Path: path}, nil
|
||||
}
|
||||
key, err := uploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path)
|
||||
key, err := debug.UploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path)
|
||||
if err != nil {
|
||||
log.Errorf("failed to upload debug bundle to %s: %v", req.GetUploadURL(), err)
|
||||
return &proto.DebugBundleResponse{Path: path, UploadFailureReason: err.Error()}, nil
|
||||
@@ -66,92 +81,6 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
return &proto.DebugBundleResponse{Path: path, UploadedKey: key}, nil
|
||||
}
|
||||
|
||||
func uploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) {
|
||||
response, err := getUploadURL(ctx, url, managementURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = upload(ctx, filePath, response)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return response.Key, nil
|
||||
}
|
||||
|
||||
func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error {
|
||||
fileData, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open file: %w", err)
|
||||
}
|
||||
|
||||
defer fileData.Close()
|
||||
|
||||
stat, err := fileData.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat file: %w", err)
|
||||
}
|
||||
|
||||
if stat.Size() > maxBundleUploadSize {
|
||||
return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create PUT request: %w", err)
|
||||
}
|
||||
|
||||
req.ContentLength = stat.Size()
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
|
||||
putResp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("upload failed: %v", err)
|
||||
}
|
||||
defer putResp.Body.Close()
|
||||
|
||||
if putResp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(putResp.Body)
|
||||
return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) {
|
||||
id := getURLHash(managementURL)
|
||||
getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create GET request: %w", err)
|
||||
}
|
||||
|
||||
getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue)
|
||||
|
||||
resp, err := http.DefaultClient.Do(getReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get presigned URL: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
urlBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response body: %w", err)
|
||||
}
|
||||
var response types.GetURLResponse
|
||||
if err := json.Unmarshal(urlBytes, &response); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response: %w", err)
|
||||
}
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
func getURLHash(url string) string {
|
||||
return fmt.Sprintf("%x", sha256.Sum256([]byte(url)))
|
||||
}
|
||||
|
||||
// GetLogLevel gets the current logging level for the server.
|
||||
func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) {
|
||||
s.mutex.Lock()
|
||||
@@ -204,3 +133,43 @@ func (s *Server) getLatestSyncResponse() (*mgmProto.SyncResponse, error) {
|
||||
|
||||
return cClient.GetLatestSyncResponse()
|
||||
}
|
||||
|
||||
// StartCPUProfile starts CPU profiling in the daemon.
|
||||
func (s *Server) StartCPUProfile(_ context.Context, _ *proto.StartCPUProfileRequest) (*proto.StartCPUProfileResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.cpuProfiling {
|
||||
return nil, fmt.Errorf("CPU profiling already in progress")
|
||||
}
|
||||
|
||||
s.cpuProfileBuf = &bytes.Buffer{}
|
||||
s.cpuProfiling = true
|
||||
if err := pprof.StartCPUProfile(s.cpuProfileBuf); err != nil {
|
||||
s.cpuProfileBuf = nil
|
||||
s.cpuProfiling = false
|
||||
return nil, fmt.Errorf("start CPU profile: %w", err)
|
||||
}
|
||||
|
||||
log.Info("CPU profiling started")
|
||||
return &proto.StartCPUProfileResponse{}, nil
|
||||
}
|
||||
|
||||
// StopCPUProfile stops CPU profiling in the daemon.
|
||||
func (s *Server) StopCPUProfile(_ context.Context, _ *proto.StopCPUProfileRequest) (*proto.StopCPUProfileResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if !s.cpuProfiling {
|
||||
return nil, fmt.Errorf("CPU profiling not in progress")
|
||||
}
|
||||
|
||||
pprof.StopCPUProfile()
|
||||
s.cpuProfiling = false
|
||||
|
||||
if s.cpuProfileBuf != nil {
|
||||
log.Infof("CPU profiling stopped, captured %d bytes", s.cpuProfileBuf.Len())
|
||||
}
|
||||
|
||||
return &proto.StopCPUProfileResponse{}, nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -13,9 +14,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
@@ -67,7 +67,7 @@ type Server struct {
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
clientRunning bool // protected by mutex
|
||||
clientRunningChan chan struct{}
|
||||
clientGiveUpChan chan struct{}
|
||||
clientGiveUpChan chan struct{} // closed when connectWithRetryRuns goroutine exits
|
||||
|
||||
connectClient *internal.ConnectClient
|
||||
|
||||
@@ -78,6 +78,9 @@ type Server struct {
|
||||
persistSyncResponse bool
|
||||
isSessionActive atomic.Bool
|
||||
|
||||
cpuProfileBuf *bytes.Buffer
|
||||
cpuProfiling bool
|
||||
|
||||
profileManager *profilemanager.ServiceManager
|
||||
profilesDisabled bool
|
||||
updateSettingsDisabled bool
|
||||
@@ -793,9 +796,11 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
|
||||
// Down engine work in the daemon.
|
||||
func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
giveUpChan := s.clientGiveUpChan
|
||||
|
||||
if err := s.cleanupConnection(); err != nil {
|
||||
s.mutex.Unlock()
|
||||
// todo review to update the status in case any type of error
|
||||
log.Errorf("failed to shut down properly: %v", err)
|
||||
return nil, err
|
||||
@@ -804,6 +809,20 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(internal.StatusIdle)
|
||||
|
||||
s.mutex.Unlock()
|
||||
|
||||
// Wait for the connectWithRetryRuns goroutine to finish with a short timeout.
|
||||
// This prevents the goroutine from setting ErrResetConnection after Down() returns.
|
||||
// The giveUpChan is closed at the end of connectWithRetryRuns.
|
||||
if giveUpChan != nil {
|
||||
select {
|
||||
case <-giveUpChan:
|
||||
log.Debugf("client goroutine finished successfully")
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Warnf("timeout waiting for client goroutine to finish, proceeding anyway")
|
||||
}
|
||||
}
|
||||
|
||||
return &proto.DownResponse{}, nil
|
||||
}
|
||||
|
||||
@@ -1308,6 +1327,10 @@ func (s *Server) runProbes(waitForProbeResult bool) {
|
||||
if engine.RunHealthProbes(waitForProbeResult) {
|
||||
s.lastProbe = time.Now()
|
||||
}
|
||||
} else {
|
||||
if err := s.statusRecorder.RefreshWireGuardStats(); err != nil {
|
||||
log.Debugf("failed to refresh WireGuard stats: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1521,7 +1544,7 @@ func (s *Server) connect(ctx context.Context, config *profilemanager.Config, sta
|
||||
log.Tracef("running client connection")
|
||||
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate)
|
||||
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
|
||||
if err := s.connectClient.Run(runningChan); err != nil {
|
||||
if err := s.connectClient.Run(runningChan, s.logFile); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
@@ -306,6 +307,8 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
peersManager := peers.NewManager(store, permissionsManagerMock)
|
||||
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
@@ -317,7 +320,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@@ -326,7 +329,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -11,8 +11,12 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
probeRelay "github.com/netbirdio/netbird/client/internal/relay"
|
||||
@@ -116,9 +120,7 @@ type OutputOverview struct {
|
||||
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
|
||||
}
|
||||
|
||||
func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview {
|
||||
pbFullStatus := resp.GetFullStatus()
|
||||
|
||||
func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, anon bool, daemonVersion string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview {
|
||||
managementState := pbFullStatus.GetManagementState()
|
||||
managementOverview := ManagementStateOutput{
|
||||
URL: managementState.GetURL(),
|
||||
@@ -134,13 +136,13 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status
|
||||
}
|
||||
|
||||
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
||||
peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter)
|
||||
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
|
||||
peersOverview := mapPeers(pbFullStatus.GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter)
|
||||
|
||||
overview := OutputOverview{
|
||||
Peers: peersOverview,
|
||||
CliVersion: version.NetbirdVersion(),
|
||||
DaemonVersion: resp.GetDaemonVersion(),
|
||||
DaemonVersion: daemonVersion,
|
||||
ManagementState: managementOverview,
|
||||
SignalState: signalOverview,
|
||||
Relays: relayOverview,
|
||||
@@ -489,6 +491,11 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
|
||||
|
||||
var forwardingRulesString string
|
||||
if o.NumberOfForwardingRules > 0 {
|
||||
forwardingRulesString = fmt.Sprintf("Forwarding rules: %d\n", o.NumberOfForwardingRules)
|
||||
}
|
||||
|
||||
goos := runtime.GOOS
|
||||
goarch := runtime.GOARCH
|
||||
goarm := ""
|
||||
@@ -512,7 +519,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
"Lazy connection: %s\n"+
|
||||
"SSH Server: %s\n"+
|
||||
"Networks: %s\n"+
|
||||
"Forwarding rules: %d\n"+
|
||||
"%s"+
|
||||
"Peers count: %s\n",
|
||||
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
|
||||
o.DaemonVersion,
|
||||
@@ -529,7 +536,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
lazyConnectionEnabledStatus,
|
||||
sshServerStatus,
|
||||
networks,
|
||||
o.NumberOfForwardingRules,
|
||||
forwardingRulesString,
|
||||
peersCountString,
|
||||
)
|
||||
return summary
|
||||
@@ -553,6 +560,94 @@ func (o *OutputOverview) FullDetailSummary() string {
|
||||
)
|
||||
}
|
||||
|
||||
func ToProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
||||
pbFullStatus := proto.FullStatus{
|
||||
ManagementState: &proto.ManagementState{},
|
||||
SignalState: &proto.SignalState{},
|
||||
LocalPeerState: &proto.LocalPeerState{},
|
||||
Peers: []*proto.PeerState{},
|
||||
}
|
||||
|
||||
pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL
|
||||
pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected
|
||||
if err := fullStatus.ManagementState.Error; err != nil {
|
||||
pbFullStatus.ManagementState.Error = err.Error()
|
||||
}
|
||||
|
||||
pbFullStatus.SignalState.URL = fullStatus.SignalState.URL
|
||||
pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected
|
||||
if err := fullStatus.SignalState.Error; err != nil {
|
||||
pbFullStatus.SignalState.Error = err.Error()
|
||||
}
|
||||
|
||||
pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP
|
||||
pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey
|
||||
pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface
|
||||
pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN
|
||||
pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive
|
||||
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
|
||||
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
|
||||
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
|
||||
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
|
||||
|
||||
for _, peerState := range fullStatus.Peers {
|
||||
pbPeerState := &proto.PeerState{
|
||||
IP: peerState.IP,
|
||||
PubKey: peerState.PubKey,
|
||||
ConnStatus: peerState.ConnStatus.String(),
|
||||
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
|
||||
Relayed: peerState.Relayed,
|
||||
LocalIceCandidateType: peerState.LocalIceCandidateType,
|
||||
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
|
||||
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
|
||||
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
|
||||
RelayAddress: peerState.RelayServerAddress,
|
||||
Fqdn: peerState.FQDN,
|
||||
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
|
||||
BytesRx: peerState.BytesRx,
|
||||
BytesTx: peerState.BytesTx,
|
||||
RosenpassEnabled: peerState.RosenpassEnabled,
|
||||
Networks: maps.Keys(peerState.GetRoutes()),
|
||||
Latency: durationpb.New(peerState.Latency),
|
||||
SshHostKey: peerState.SSHHostKey,
|
||||
}
|
||||
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
|
||||
}
|
||||
|
||||
for _, relayState := range fullStatus.Relays {
|
||||
pbRelayState := &proto.RelayState{
|
||||
URI: relayState.URI,
|
||||
Available: relayState.Err == nil,
|
||||
}
|
||||
if err := relayState.Err; err != nil {
|
||||
pbRelayState.Error = err.Error()
|
||||
}
|
||||
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
|
||||
}
|
||||
|
||||
for _, dnsState := range fullStatus.NSGroupStates {
|
||||
var err string
|
||||
if dnsState.Error != nil {
|
||||
err = dnsState.Error.Error()
|
||||
}
|
||||
|
||||
var servers []string
|
||||
for _, server := range dnsState.Servers {
|
||||
servers = append(servers, server.String())
|
||||
}
|
||||
|
||||
pbDnsState := &proto.NSGroupState{
|
||||
Servers: servers,
|
||||
Domains: dnsState.Domains,
|
||||
Enabled: dnsState.Enabled,
|
||||
Error: err,
|
||||
}
|
||||
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
|
||||
}
|
||||
|
||||
return &pbFullStatus
|
||||
}
|
||||
|
||||
func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
|
||||
var (
|
||||
peersString = ""
|
||||
|
||||
@@ -238,7 +238,7 @@ var overview = OutputOverview{
|
||||
}
|
||||
|
||||
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
||||
convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil, "", "")
|
||||
convertedResult := ConvertToStatusOutputOverview(resp.GetFullStatus(), false, resp.GetDaemonVersion(), "", nil, nil, nil, "", "")
|
||||
|
||||
assert.Equal(t, overview, convertedResult)
|
||||
}
|
||||
@@ -567,7 +567,6 @@ Quantum resistance: false
|
||||
Lazy connection: false
|
||||
SSH Server: Disabled
|
||||
Networks: 10.10.0.0/24
|
||||
Forwarding rules: 0
|
||||
Peers count: 2/2 Connected
|
||||
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
|
||||
|
||||
@@ -592,7 +591,6 @@ Quantum resistance: false
|
||||
Lazy connection: false
|
||||
SSH Server: Disabled
|
||||
Networks: 10.10.0.0/24
|
||||
Forwarding rules: 0
|
||||
Peers count: 2/2 Connected
|
||||
`
|
||||
|
||||
|
||||
@@ -1033,7 +1033,7 @@ func (s *serviceClient) onTrayReady() {
|
||||
s.mDown.Disable()
|
||||
systray.AddSeparator()
|
||||
|
||||
s.mSettings = systray.AddMenuItem("Settings", settingsMenuDescr)
|
||||
s.mSettings = systray.AddMenuItem("Settings", disabledMenuDescr)
|
||||
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", allowSSHMenuDescr, false)
|
||||
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", autoConnectMenuDescr, false)
|
||||
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", quantumResistanceMenuDescr, false)
|
||||
@@ -1060,7 +1060,7 @@ func (s *serviceClient) onTrayReady() {
|
||||
}
|
||||
|
||||
s.exitNodeMu.Lock()
|
||||
s.mExitNode = systray.AddMenuItem("Exit Node", exitNodeMenuDescr)
|
||||
s.mExitNode = systray.AddMenuItem("Exit Node", disabledMenuDescr)
|
||||
s.mExitNode.Disable()
|
||||
s.exitNodeMu.Unlock()
|
||||
|
||||
@@ -1261,7 +1261,6 @@ func (s *serviceClient) setSettingsEnabled(enabled bool) {
|
||||
if s.mSettings != nil {
|
||||
if enabled {
|
||||
s.mSettings.Enable()
|
||||
s.mSettings.SetTooltip(settingsMenuDescr)
|
||||
} else {
|
||||
s.mSettings.Hide()
|
||||
s.mSettings.SetTooltip("Settings are disabled by daemon")
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package main
|
||||
|
||||
const (
|
||||
settingsMenuDescr = "Settings of the application"
|
||||
profilesMenuDescr = "Manage your profiles"
|
||||
allowSSHMenuDescr = "Allow SSH connections"
|
||||
autoConnectMenuDescr = "Connect automatically when the service starts"
|
||||
quantumResistanceMenuDescr = "Enable post-quantum security via Rosenpass"
|
||||
@@ -11,7 +9,7 @@ const (
|
||||
notificationsMenuDescr = "Enable notifications"
|
||||
advancedSettingsMenuDescr = "Advanced settings of the application"
|
||||
debugBundleMenuDescr = "Create and open debug information bundle"
|
||||
exitNodeMenuDescr = "Select exit node for routing traffic"
|
||||
disabledMenuDescr = ""
|
||||
networksMenuDescr = "Open the networks management window"
|
||||
latestVersionMenuDescr = "Download latest version"
|
||||
quitMenuDescr = "Quit the client app"
|
||||
|
||||
@@ -18,9 +18,7 @@ import (
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
uptypes "github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
@@ -291,19 +289,18 @@ func (s *serviceClient) handleRunForDuration(
|
||||
return
|
||||
}
|
||||
|
||||
statusOutput, err := s.collectDebugData(conn, initialState, params, progressUI)
|
||||
if err != nil {
|
||||
defer s.restoreServiceState(conn, initialState)
|
||||
|
||||
if err := s.collectDebugData(conn, initialState, params, progressUI); err != nil {
|
||||
handleError(progressUI, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.createDebugBundleFromCollection(conn, params, statusOutput, progressUI); err != nil {
|
||||
if err := s.createDebugBundleFromCollection(conn, params, progressUI); err != nil {
|
||||
handleError(progressUI, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
s.restoreServiceState(conn, initialState)
|
||||
|
||||
progressUI.statusLabel.SetText("Bundle created successfully")
|
||||
}
|
||||
|
||||
@@ -409,6 +406,10 @@ func (s *serviceClient) configureServiceForDebug(
|
||||
}
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil {
|
||||
log.Warnf("failed to start CPU profiling: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -417,68 +418,37 @@ func (s *serviceClient) collectDebugData(
|
||||
state *debugInitialState,
|
||||
params *debugCollectionParams,
|
||||
progress *progressUI,
|
||||
) (string, error) {
|
||||
) error {
|
||||
ctx, cancel := context.WithTimeout(s.ctx, params.duration)
|
||||
defer cancel()
|
||||
var wg sync.WaitGroup
|
||||
startProgressTracker(ctx, &wg, params.duration, progress)
|
||||
|
||||
if err := s.configureServiceForDebug(conn, state, params.enablePersistence); err != nil {
|
||||
return "", err
|
||||
return err
|
||||
}
|
||||
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get post-up status: %v", err)
|
||||
}
|
||||
|
||||
var postUpStatusOutput string
|
||||
if postUpStatus != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
||||
postUpStatusOutput = overview.FullDetailSummary()
|
||||
}
|
||||
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, postUpStatusOutput)
|
||||
|
||||
wg.Wait()
|
||||
progress.progressBar.Hide()
|
||||
progress.statusLabel.SetText("Collecting debug data...")
|
||||
|
||||
preDownStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get pre-down status: %v", err)
|
||||
if _, err := conn.StopCPUProfile(s.ctx, &proto.StopCPUProfileRequest{}); err != nil {
|
||||
log.Warnf("failed to stop CPU profiling: %v", err)
|
||||
}
|
||||
|
||||
var preDownStatusOutput string
|
||||
if preDownStatus != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
||||
preDownStatusOutput = overview.FullDetailSummary()
|
||||
}
|
||||
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
|
||||
time.Now().Format(time.RFC3339), params.duration)
|
||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, preDownStatusOutput)
|
||||
|
||||
return statusOutput, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create the debug bundle with collected data
|
||||
func (s *serviceClient) createDebugBundleFromCollection(
|
||||
conn proto.DaemonServiceClient,
|
||||
params *debugCollectionParams,
|
||||
statusOutput string,
|
||||
progress *progressUI,
|
||||
) error {
|
||||
progress.statusLabel.SetText("Creating debug bundle with collected logs...")
|
||||
|
||||
request := &proto.DebugBundleRequest{
|
||||
Anonymize: params.anonymize,
|
||||
Status: statusOutput,
|
||||
SystemInfo: params.systemInfo,
|
||||
}
|
||||
|
||||
@@ -581,26 +551,8 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
|
||||
return nil, fmt.Errorf("get client: %v", err)
|
||||
}
|
||||
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
log.Warnf("failed to get status for debug bundle: %v", err)
|
||||
}
|
||||
|
||||
var statusOutput string
|
||||
if statusResp != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
|
||||
statusOutput = overview.FullDetailSummary()
|
||||
}
|
||||
|
||||
request := &proto.DebugBundleRequest{
|
||||
Anonymize: anonymize,
|
||||
Status: statusOutput,
|
||||
SystemInfo: systemInfo,
|
||||
}
|
||||
|
||||
|
||||
@@ -99,6 +99,8 @@ func (h *eventHandler) handleConnectClick() {
|
||||
func (h *eventHandler) handleDisconnectClick() {
|
||||
h.client.mDown.Disable()
|
||||
|
||||
h.client.exitNodeStates = []exitNodeState{}
|
||||
|
||||
if h.client.connectCancel != nil {
|
||||
log.Debugf("cancelling ongoing connect operation")
|
||||
h.client.connectCancel()
|
||||
|
||||
@@ -390,7 +390,7 @@ func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) {
|
||||
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "freebsd" {
|
||||
s.mExitNode.Remove()
|
||||
s.mExitNode = systray.AddMenuItem("Exit Node", exitNodeMenuDescr)
|
||||
s.mExitNode = systray.AddMenuItem("Exit Node", disabledMenuDescr)
|
||||
}
|
||||
|
||||
var showDeselectAll bool
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
netbird "github.com/netbirdio/netbird/client/embed"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
||||
@@ -350,12 +349,8 @@ func getStatusOverview(client *netbird.Client) (nbstatus.OutputOverview, error)
|
||||
}
|
||||
|
||||
pbFullStatus := fullStatus.ToProto()
|
||||
statusResp := &proto.StatusResponse{
|
||||
DaemonVersion: version.NetbirdVersion(),
|
||||
FullStatus: pbFullStatus,
|
||||
}
|
||||
|
||||
return nbstatus.ConvertToStatusOutputOverview(statusResp, false, "", nil, nil, nil, "", ""), nil
|
||||
return nbstatus.ConvertToStatusOutputOverview(pbFullStatus, false, version.NetbirdVersion(), "", nil, nil, nil, "", ""), nil
|
||||
}
|
||||
|
||||
// createStatusMethod creates the status method that returns JSON
|
||||
|
||||
4
go.mod
4
go.mod
@@ -68,8 +68,9 @@ require (
|
||||
github.com/mdlayher/socket v0.5.1
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||
github.com/oapi-codegen/runtime v1.1.2
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
@@ -141,6 +142,7 @@ require (
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/Microsoft/hcsshim v0.12.3 // indirect
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
|
||||
github.com/awnumar/memcall v0.4.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect
|
||||
|
||||
12
go.sum
12
go.sum
@@ -35,12 +35,15 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/Microsoft/hcsshim v0.12.3 h1:LS9NXqXhMoqNCplK1ApmVSfB4UnVLRDWRapB6EIlxE0=
|
||||
github.com/Microsoft/hcsshim v0.12.3/go.mod h1:Iyl1WVpZzr+UkzjekHZbV8o5Z9ZkxNGx6CtY2Qg/JVQ=
|
||||
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
|
||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJTa/axESA0889D3UlZbLo=
|
||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I=
|
||||
github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI=
|
||||
github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ=
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk=
|
||||
github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g=
|
||||
github.com/awnumar/memcall v0.4.0/go.mod h1:8xOx1YbfyuCg3Fy6TO8DK0kZUua3V42/goA5Ru47E8w=
|
||||
github.com/awnumar/memguard v0.23.0 h1:sJ3a1/SWlcuKIQ7MV+R9p0Pvo9CWsMbGZvcZQtmc68A=
|
||||
@@ -87,6 +90,7 @@ github.com/beevik/etree v1.6.0 h1:u8Kwy8pp9D9XeITj2Z0XtA5qqZEmtJtuXZRQi+j03eE=
|
||||
github.com/beevik/etree v1.6.0/go.mod h1:bh4zJxiIr62SOf9pRzN7UUYaEDa9HEKafK25+sLc0Gc=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
@@ -320,6 +324,7 @@ github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7X
|
||||
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M=
|
||||
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
|
||||
github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE=
|
||||
github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8=
|
||||
github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
@@ -401,8 +406,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 h1:V0zsYYMU5d2UN1m9zOLPEZCGWpnhtkYcxQVi9Rrx3bY=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f h1:CTBf0je/FpKr2lVSMZLak7m8aaWcS6ur4SOfhSSazFI=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||
@@ -416,6 +421,8 @@ github.com/nicksnyder/go-i18n/v2 v2.5.1/go.mod h1:DrhgsSDZxoAfvVrBVLXoxZn/pN5TXq
|
||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
||||
github.com/oapi-codegen/runtime v1.1.2 h1:P2+CubHq8fO4Q6fV1tqDBZHCwpVpvPg7oKiYzQgXIyI=
|
||||
github.com/oapi-codegen/runtime v1.1.2/go.mod h1:SK9X900oXmPWilYR5/WKPzt3Kqxn/uS/+lbpREv+eCg=
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0 h1:cfDasMb7CShbZvOrF6n+DnLevWwiHgedWMGJ8M8xKDc=
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0/go.mod h1:dz30v3ctAiMb7jpsCngGfQUAEGm1/NsWT92uTbNDQIs=
|
||||
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||
@@ -522,6 +529,7 @@ github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s=
|
||||
github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0=
|
||||
github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY=
|
||||
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
|
||||
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c h1:km8GpoQut05eY3GiYWEedbTT0qnSxrCjsVbb7yKY1KE=
|
||||
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c/go.mod h1:cNQ3dwVJtS5Hmnjxy6AgTPd0Inb3pW05ftPSX7NZO7Q=
|
||||
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef h1:Ch6Q+AZUxDBCVqdkI8FSpFyZDtCVBc2VmejdNrm5rRQ=
|
||||
|
||||
356
idp/dex/connector.go
Normal file
356
idp/dex/connector.go
Normal file
@@ -0,0 +1,356 @@
|
||||
// Package dex provides an embedded Dex OIDC identity provider.
|
||||
package dex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
)
|
||||
|
||||
// ConnectorConfig represents the configuration for an identity provider connector
|
||||
type ConnectorConfig struct {
|
||||
// ID is the unique identifier for the connector
|
||||
ID string
|
||||
// Name is a human-readable name for the connector
|
||||
Name string
|
||||
// Type is the connector type (oidc, google, microsoft)
|
||||
Type string
|
||||
// Issuer is the OIDC issuer URL (for OIDC-based connectors)
|
||||
Issuer string
|
||||
// ClientID is the OAuth2 client ID
|
||||
ClientID string
|
||||
// ClientSecret is the OAuth2 client secret
|
||||
ClientSecret string
|
||||
// RedirectURI is the OAuth2 redirect URI
|
||||
RedirectURI string
|
||||
}
|
||||
|
||||
// CreateConnector creates a new connector in Dex storage.
|
||||
// It maps the connector config to the appropriate Dex connector type and configuration.
|
||||
func (p *Provider) CreateConnector(ctx context.Context, cfg *ConnectorConfig) (*ConnectorConfig, error) {
|
||||
// Fill in the redirect URI if not provided
|
||||
if cfg.RedirectURI == "" {
|
||||
cfg.RedirectURI = p.GetRedirectURI()
|
||||
}
|
||||
|
||||
storageConn, err := p.buildStorageConnector(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build connector: %w", err)
|
||||
}
|
||||
|
||||
if err := p.storage.CreateConnector(ctx, storageConn); err != nil {
|
||||
return nil, fmt.Errorf("failed to create connector: %w", err)
|
||||
}
|
||||
|
||||
p.logger.Info("connector created", "id", cfg.ID, "type", cfg.Type)
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// GetConnector retrieves a connector by ID from Dex storage.
|
||||
func (p *Provider) GetConnector(ctx context.Context, id string) (*ConnectorConfig, error) {
|
||||
conn, err := p.storage.GetConnector(ctx, id)
|
||||
if err != nil {
|
||||
if err == storage.ErrNotFound {
|
||||
return nil, err
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get connector: %w", err)
|
||||
}
|
||||
|
||||
return p.parseStorageConnector(conn)
|
||||
}
|
||||
|
||||
// ListConnectors returns all connectors from Dex storage (excluding the local connector).
|
||||
func (p *Provider) ListConnectors(ctx context.Context) ([]*ConnectorConfig, error) {
|
||||
connectors, err := p.storage.ListConnectors(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list connectors: %w", err)
|
||||
}
|
||||
|
||||
result := make([]*ConnectorConfig, 0, len(connectors))
|
||||
for _, conn := range connectors {
|
||||
// Skip the local password connector
|
||||
if conn.ID == "local" && conn.Type == "local" {
|
||||
continue
|
||||
}
|
||||
|
||||
cfg, err := p.parseStorageConnector(conn)
|
||||
if err != nil {
|
||||
p.logger.Warn("failed to parse connector", "id", conn.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
result = append(result, cfg)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// UpdateConnector updates an existing connector in Dex storage.
|
||||
// It merges incoming updates with existing values to prevent data loss on partial updates.
|
||||
func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error {
|
||||
if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) {
|
||||
oldCfg, err := p.parseStorageConnector(old)
|
||||
if err != nil {
|
||||
return storage.Connector{}, fmt.Errorf("failed to parse existing connector: %w", err)
|
||||
}
|
||||
|
||||
mergeConnectorConfig(cfg, oldCfg)
|
||||
|
||||
storageConn, err := p.buildStorageConnector(cfg)
|
||||
if err != nil {
|
||||
return storage.Connector{}, fmt.Errorf("failed to build connector: %w", err)
|
||||
}
|
||||
return storageConn, nil
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to update connector: %w", err)
|
||||
}
|
||||
|
||||
p.logger.Info("connector updated", "id", cfg.ID, "type", cfg.Type)
|
||||
return nil
|
||||
}
|
||||
|
||||
// mergeConnectorConfig preserves existing values for empty fields in the update.
|
||||
func mergeConnectorConfig(cfg, oldCfg *ConnectorConfig) {
|
||||
if cfg.ClientSecret == "" {
|
||||
cfg.ClientSecret = oldCfg.ClientSecret
|
||||
}
|
||||
if cfg.RedirectURI == "" {
|
||||
cfg.RedirectURI = oldCfg.RedirectURI
|
||||
}
|
||||
if cfg.Issuer == "" && cfg.Type == oldCfg.Type {
|
||||
cfg.Issuer = oldCfg.Issuer
|
||||
}
|
||||
if cfg.ClientID == "" {
|
||||
cfg.ClientID = oldCfg.ClientID
|
||||
}
|
||||
if cfg.Name == "" {
|
||||
cfg.Name = oldCfg.Name
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteConnector removes a connector from Dex storage.
|
||||
func (p *Provider) DeleteConnector(ctx context.Context, id string) error {
|
||||
// Prevent deletion of the local connector
|
||||
if id == "local" {
|
||||
return fmt.Errorf("cannot delete the local password connector")
|
||||
}
|
||||
|
||||
if err := p.storage.DeleteConnector(ctx, id); err != nil {
|
||||
return fmt.Errorf("failed to delete connector: %w", err)
|
||||
}
|
||||
|
||||
p.logger.Info("connector deleted", "id", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRedirectURI returns the default redirect URI for connectors.
|
||||
func (p *Provider) GetRedirectURI() string {
|
||||
if p.config == nil {
|
||||
return ""
|
||||
}
|
||||
issuer := strings.TrimSuffix(p.config.Issuer, "/")
|
||||
if !strings.HasSuffix(issuer, "/oauth2") {
|
||||
issuer += "/oauth2"
|
||||
}
|
||||
return issuer + "/callback"
|
||||
}
|
||||
|
||||
// buildStorageConnector creates a storage.Connector from ConnectorConfig.
|
||||
// It handles the type-specific configuration for each connector type.
|
||||
func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connector, error) {
|
||||
redirectURI := p.resolveRedirectURI(cfg.RedirectURI)
|
||||
|
||||
var dexType string
|
||||
var configData []byte
|
||||
var err error
|
||||
|
||||
switch cfg.Type {
|
||||
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
|
||||
dexType = "oidc"
|
||||
configData, err = buildOIDCConnectorConfig(cfg, redirectURI)
|
||||
case "google":
|
||||
dexType = "google"
|
||||
configData, err = buildOAuth2ConnectorConfig(cfg, redirectURI)
|
||||
case "microsoft":
|
||||
dexType = "microsoft"
|
||||
configData, err = buildOAuth2ConnectorConfig(cfg, redirectURI)
|
||||
default:
|
||||
return storage.Connector{}, fmt.Errorf("unsupported connector type: %s", cfg.Type)
|
||||
}
|
||||
if err != nil {
|
||||
return storage.Connector{}, err
|
||||
}
|
||||
|
||||
return storage.Connector{ID: cfg.ID, Type: dexType, Name: cfg.Name, Config: configData}, nil
|
||||
}
|
||||
|
||||
// resolveRedirectURI returns the redirect URI, using a default if not provided
|
||||
func (p *Provider) resolveRedirectURI(redirectURI string) string {
|
||||
if redirectURI != "" || p.config == nil {
|
||||
return redirectURI
|
||||
}
|
||||
issuer := strings.TrimSuffix(p.config.Issuer, "/")
|
||||
if !strings.HasSuffix(issuer, "/oauth2") {
|
||||
issuer += "/oauth2"
|
||||
}
|
||||
return issuer + "/callback"
|
||||
}
|
||||
|
||||
// buildOIDCConnectorConfig creates config for OIDC-based connectors
|
||||
func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) {
|
||||
oidcConfig := map[string]interface{}{
|
||||
"issuer": cfg.Issuer,
|
||||
"clientID": cfg.ClientID,
|
||||
"clientSecret": cfg.ClientSecret,
|
||||
"redirectURI": redirectURI,
|
||||
"scopes": []string{"openid", "profile", "email"},
|
||||
"insecureEnableGroups": true,
|
||||
//some providers don't return email verified, so we need to skip it if not present (e.g., Entra, Okta, Duo)
|
||||
"insecureSkipEmailVerified": true,
|
||||
}
|
||||
switch cfg.Type {
|
||||
case "zitadel":
|
||||
oidcConfig["getUserInfo"] = true
|
||||
case "entra":
|
||||
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
|
||||
case "okta":
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
case "pocketid":
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
}
|
||||
return encodeConnectorConfig(oidcConfig)
|
||||
}
|
||||
|
||||
// buildOAuth2ConnectorConfig creates config for OAuth2 connectors (google, microsoft)
|
||||
func buildOAuth2ConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) {
|
||||
return encodeConnectorConfig(map[string]interface{}{
|
||||
"clientID": cfg.ClientID,
|
||||
"clientSecret": cfg.ClientSecret,
|
||||
"redirectURI": redirectURI,
|
||||
})
|
||||
}
|
||||
|
||||
// parseStorageConnector converts a storage.Connector back to ConnectorConfig.
|
||||
// It infers the original identity provider type from the Dex connector type and ID.
|
||||
func (p *Provider) parseStorageConnector(conn storage.Connector) (*ConnectorConfig, error) {
|
||||
cfg := &ConnectorConfig{
|
||||
ID: conn.ID,
|
||||
Name: conn.Name,
|
||||
}
|
||||
|
||||
if len(conn.Config) == 0 {
|
||||
cfg.Type = conn.Type
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
var configMap map[string]interface{}
|
||||
if err := decodeConnectorConfig(conn.Config, &configMap); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse connector config: %w", err)
|
||||
}
|
||||
|
||||
// Extract common fields
|
||||
if v, ok := configMap["clientID"].(string); ok {
|
||||
cfg.ClientID = v
|
||||
}
|
||||
if v, ok := configMap["clientSecret"].(string); ok {
|
||||
cfg.ClientSecret = v
|
||||
}
|
||||
if v, ok := configMap["redirectURI"].(string); ok {
|
||||
cfg.RedirectURI = v
|
||||
}
|
||||
if v, ok := configMap["issuer"].(string); ok {
|
||||
cfg.Issuer = v
|
||||
}
|
||||
|
||||
// Infer the original identity provider type from Dex connector type and ID
|
||||
cfg.Type = inferIdentityProviderType(conn.Type, conn.ID, configMap)
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// inferIdentityProviderType determines the original identity provider type
|
||||
// based on the Dex connector type, connector ID, and configuration.
|
||||
func inferIdentityProviderType(dexType, connectorID string, _ map[string]interface{}) string {
|
||||
if dexType != "oidc" {
|
||||
return dexType
|
||||
}
|
||||
return inferOIDCProviderType(connectorID)
|
||||
}
|
||||
|
||||
// inferOIDCProviderType infers the specific OIDC provider from connector ID
|
||||
func inferOIDCProviderType(connectorID string) string {
|
||||
connectorIDLower := strings.ToLower(connectorID)
|
||||
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"} {
|
||||
if strings.Contains(connectorIDLower, provider) {
|
||||
return provider
|
||||
}
|
||||
}
|
||||
return "oidc"
|
||||
}
|
||||
|
||||
// encodeConnectorConfig serializes connector config to JSON bytes.
|
||||
func encodeConnectorConfig(config map[string]interface{}) ([]byte, error) {
|
||||
return json.Marshal(config)
|
||||
}
|
||||
|
||||
// decodeConnectorConfig deserializes connector config from JSON bytes.
|
||||
func decodeConnectorConfig(data []byte, v interface{}) error {
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
// ensureLocalConnector creates a local (password) connector if it doesn't exist
|
||||
func ensureLocalConnector(ctx context.Context, stor storage.Storage) error {
|
||||
// Check specifically for the local connector
|
||||
_, err := stor.GetConnector(ctx, "local")
|
||||
if err == nil {
|
||||
// Local connector already exists
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(err, storage.ErrNotFound) {
|
||||
return fmt.Errorf("failed to get local connector: %w", err)
|
||||
}
|
||||
|
||||
// Create a local connector for password authentication
|
||||
localConnector := storage.Connector{
|
||||
ID: "local",
|
||||
Type: "local",
|
||||
Name: "Email",
|
||||
}
|
||||
|
||||
if err := stor.CreateConnector(ctx, localConnector); err != nil {
|
||||
return fmt.Errorf("failed to create local connector: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureStaticConnectors creates or updates static connectors in storage
|
||||
func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error {
|
||||
for _, conn := range connectors {
|
||||
storConn, err := conn.ToStorageConnector()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to convert connector %s: %w", conn.ID, err)
|
||||
}
|
||||
_, err = stor.GetConnector(ctx, conn.ID)
|
||||
if err == storage.ErrNotFound {
|
||||
if err := stor.CreateConnector(ctx, storConn); err != nil {
|
||||
return fmt.Errorf("failed to create connector %s: %w", conn.ID, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get connector %s: %w", conn.ID, err)
|
||||
}
|
||||
if err := stor.UpdateConnector(ctx, conn.ID, func(old storage.Connector) (storage.Connector, error) {
|
||||
old.Name = storConn.Name
|
||||
old.Config = storConn.Config
|
||||
return old, nil
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to update connector %s: %w", conn.ID, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -4,7 +4,6 @@ package dex
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
@@ -245,34 +244,6 @@ func ensureStaticClients(ctx context.Context, stor storage.Storage, clients []st
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureStaticConnectors creates or updates static connectors in storage
|
||||
func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error {
|
||||
for _, conn := range connectors {
|
||||
storConn, err := conn.ToStorageConnector()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to convert connector %s: %w", conn.ID, err)
|
||||
}
|
||||
_, err = stor.GetConnector(ctx, conn.ID)
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
if err := stor.CreateConnector(ctx, storConn); err != nil {
|
||||
return fmt.Errorf("failed to create connector %s: %w", conn.ID, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get connector %s: %w", conn.ID, err)
|
||||
}
|
||||
if err := stor.UpdateConnector(ctx, conn.ID, func(old storage.Connector) (storage.Connector, error) {
|
||||
old.Name = storConn.Name
|
||||
old.Config = storConn.Config
|
||||
return old, nil
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to update connector %s: %w", conn.ID, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildDexConfig creates a server.Config with defaults applied
|
||||
func buildDexConfig(yamlConfig *YAMLConfig, stor storage.Storage, logger *slog.Logger) server.Config {
|
||||
cfg := yamlConfig.ToServerConfig(stor, logger)
|
||||
@@ -613,294 +584,37 @@ func (p *Provider) ListUsers(ctx context.Context) ([]storage.Password, error) {
|
||||
return p.storage.ListPasswords(ctx)
|
||||
}
|
||||
|
||||
// ensureLocalConnector creates a local (password) connector if none exists
|
||||
func ensureLocalConnector(ctx context.Context, stor storage.Storage) error {
|
||||
connectors, err := stor.ListConnectors(ctx)
|
||||
// UpdateUserPassword updates the password for a user identified by userID.
|
||||
// The userID can be either an encoded Dex ID (base64 protobuf) or a raw UUID.
|
||||
// It verifies the current password before updating.
|
||||
func (p *Provider) UpdateUserPassword(ctx context.Context, userID string, oldPassword, newPassword string) error {
|
||||
// Get the user by ID to find their email
|
||||
user, err := p.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list connectors: %w", err)
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
// If any connector exists, we're good
|
||||
if len(connectors) > 0 {
|
||||
return nil
|
||||
// Verify old password
|
||||
if err := bcrypt.CompareHashAndPassword(user.Hash, []byte(oldPassword)); err != nil {
|
||||
return fmt.Errorf("current password is incorrect")
|
||||
}
|
||||
|
||||
// Create a local connector for password authentication
|
||||
localConnector := storage.Connector{
|
||||
ID: "local",
|
||||
Type: "local",
|
||||
Name: "Email",
|
||||
}
|
||||
|
||||
if err := stor.CreateConnector(ctx, localConnector); err != nil {
|
||||
return fmt.Errorf("failed to create local connector: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConnectorConfig represents the configuration for an identity provider connector
|
||||
type ConnectorConfig struct {
|
||||
// ID is the unique identifier for the connector
|
||||
ID string
|
||||
// Name is a human-readable name for the connector
|
||||
Name string
|
||||
// Type is the connector type (oidc, google, microsoft)
|
||||
Type string
|
||||
// Issuer is the OIDC issuer URL (for OIDC-based connectors)
|
||||
Issuer string
|
||||
// ClientID is the OAuth2 client ID
|
||||
ClientID string
|
||||
// ClientSecret is the OAuth2 client secret
|
||||
ClientSecret string
|
||||
// RedirectURI is the OAuth2 redirect URI
|
||||
RedirectURI string
|
||||
}
|
||||
|
||||
// CreateConnector creates a new connector in Dex storage.
|
||||
// It maps the connector config to the appropriate Dex connector type and configuration.
|
||||
func (p *Provider) CreateConnector(ctx context.Context, cfg *ConnectorConfig) (*ConnectorConfig, error) {
|
||||
// Fill in the redirect URI if not provided
|
||||
if cfg.RedirectURI == "" {
|
||||
cfg.RedirectURI = p.GetRedirectURI()
|
||||
}
|
||||
|
||||
storageConn, err := p.buildStorageConnector(cfg)
|
||||
// Hash the new password
|
||||
newHash, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build connector: %w", err)
|
||||
return fmt.Errorf("failed to hash new password: %w", err)
|
||||
}
|
||||
|
||||
if err := p.storage.CreateConnector(ctx, storageConn); err != nil {
|
||||
return nil, fmt.Errorf("failed to create connector: %w", err)
|
||||
}
|
||||
|
||||
p.logger.Info("connector created", "id", cfg.ID, "type", cfg.Type)
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// GetConnector retrieves a connector by ID from Dex storage.
|
||||
func (p *Provider) GetConnector(ctx context.Context, id string) (*ConnectorConfig, error) {
|
||||
conn, err := p.storage.GetConnector(ctx, id)
|
||||
if err != nil {
|
||||
if err == storage.ErrNotFound {
|
||||
return nil, err
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get connector: %w", err)
|
||||
}
|
||||
|
||||
return p.parseStorageConnector(conn)
|
||||
}
|
||||
|
||||
// ListConnectors returns all connectors from Dex storage (excluding the local connector).
|
||||
func (p *Provider) ListConnectors(ctx context.Context) ([]*ConnectorConfig, error) {
|
||||
connectors, err := p.storage.ListConnectors(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list connectors: %w", err)
|
||||
}
|
||||
|
||||
result := make([]*ConnectorConfig, 0, len(connectors))
|
||||
for _, conn := range connectors {
|
||||
// Skip the local password connector
|
||||
if conn.ID == "local" && conn.Type == "local" {
|
||||
continue
|
||||
}
|
||||
|
||||
cfg, err := p.parseStorageConnector(conn)
|
||||
if err != nil {
|
||||
p.logger.Warn("failed to parse connector", "id", conn.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
result = append(result, cfg)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// UpdateConnector updates an existing connector in Dex storage.
|
||||
func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error {
|
||||
storageConn, err := p.buildStorageConnector(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build connector: %w", err)
|
||||
}
|
||||
|
||||
if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) {
|
||||
return storageConn, nil
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to update connector: %w", err)
|
||||
}
|
||||
|
||||
p.logger.Info("connector updated", "id", cfg.ID, "type", cfg.Type)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteConnector removes a connector from Dex storage.
|
||||
func (p *Provider) DeleteConnector(ctx context.Context, id string) error {
|
||||
// Prevent deletion of the local connector
|
||||
if id == "local" {
|
||||
return fmt.Errorf("cannot delete the local password connector")
|
||||
}
|
||||
|
||||
if err := p.storage.DeleteConnector(ctx, id); err != nil {
|
||||
return fmt.Errorf("failed to delete connector: %w", err)
|
||||
}
|
||||
|
||||
p.logger.Info("connector deleted", "id", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildStorageConnector creates a storage.Connector from ConnectorConfig.
|
||||
// It handles the type-specific configuration for each connector type.
|
||||
func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connector, error) {
|
||||
redirectURI := p.resolveRedirectURI(cfg.RedirectURI)
|
||||
|
||||
var dexType string
|
||||
var configData []byte
|
||||
var err error
|
||||
|
||||
switch cfg.Type {
|
||||
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
|
||||
dexType = "oidc"
|
||||
configData, err = buildOIDCConnectorConfig(cfg, redirectURI)
|
||||
case "google":
|
||||
dexType = "google"
|
||||
configData, err = buildOAuth2ConnectorConfig(cfg, redirectURI)
|
||||
case "microsoft":
|
||||
dexType = "microsoft"
|
||||
configData, err = buildOAuth2ConnectorConfig(cfg, redirectURI)
|
||||
default:
|
||||
return storage.Connector{}, fmt.Errorf("unsupported connector type: %s", cfg.Type)
|
||||
}
|
||||
if err != nil {
|
||||
return storage.Connector{}, err
|
||||
}
|
||||
|
||||
return storage.Connector{ID: cfg.ID, Type: dexType, Name: cfg.Name, Config: configData}, nil
|
||||
}
|
||||
|
||||
// resolveRedirectURI returns the redirect URI, using a default if not provided
|
||||
func (p *Provider) resolveRedirectURI(redirectURI string) string {
|
||||
if redirectURI != "" || p.config == nil {
|
||||
return redirectURI
|
||||
}
|
||||
issuer := strings.TrimSuffix(p.config.Issuer, "/")
|
||||
if !strings.HasSuffix(issuer, "/oauth2") {
|
||||
issuer += "/oauth2"
|
||||
}
|
||||
return issuer + "/callback"
|
||||
}
|
||||
|
||||
// buildOIDCConnectorConfig creates config for OIDC-based connectors
|
||||
func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) {
|
||||
oidcConfig := map[string]interface{}{
|
||||
"issuer": cfg.Issuer,
|
||||
"clientID": cfg.ClientID,
|
||||
"clientSecret": cfg.ClientSecret,
|
||||
"redirectURI": redirectURI,
|
||||
"scopes": []string{"openid", "profile", "email"},
|
||||
"insecureEnableGroups": true,
|
||||
//some providers don't return email verified, so we need to skip it if not present (e.g., Entra, Okta, Duo)
|
||||
"insecureSkipEmailVerified": true,
|
||||
}
|
||||
switch cfg.Type {
|
||||
case "zitadel":
|
||||
oidcConfig["getUserInfo"] = true
|
||||
case "entra":
|
||||
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
|
||||
case "okta":
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
case "pocketid":
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
}
|
||||
return encodeConnectorConfig(oidcConfig)
|
||||
}
|
||||
|
||||
// buildOAuth2ConnectorConfig creates config for OAuth2 connectors (google, microsoft)
|
||||
func buildOAuth2ConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) {
|
||||
return encodeConnectorConfig(map[string]interface{}{
|
||||
"clientID": cfg.ClientID,
|
||||
"clientSecret": cfg.ClientSecret,
|
||||
"redirectURI": redirectURI,
|
||||
// Update the password in storage
|
||||
err = p.storage.UpdatePassword(ctx, user.Email, func(old storage.Password) (storage.Password, error) {
|
||||
old.Hash = newHash
|
||||
return old, nil
|
||||
})
|
||||
}
|
||||
|
||||
// parseStorageConnector converts a storage.Connector back to ConnectorConfig.
|
||||
// It infers the original identity provider type from the Dex connector type and ID.
|
||||
func (p *Provider) parseStorageConnector(conn storage.Connector) (*ConnectorConfig, error) {
|
||||
cfg := &ConnectorConfig{
|
||||
ID: conn.ID,
|
||||
Name: conn.Name,
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update password: %w", err)
|
||||
}
|
||||
|
||||
if len(conn.Config) == 0 {
|
||||
cfg.Type = conn.Type
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
var configMap map[string]interface{}
|
||||
if err := decodeConnectorConfig(conn.Config, &configMap); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse connector config: %w", err)
|
||||
}
|
||||
|
||||
// Extract common fields
|
||||
if v, ok := configMap["clientID"].(string); ok {
|
||||
cfg.ClientID = v
|
||||
}
|
||||
if v, ok := configMap["clientSecret"].(string); ok {
|
||||
cfg.ClientSecret = v
|
||||
}
|
||||
if v, ok := configMap["redirectURI"].(string); ok {
|
||||
cfg.RedirectURI = v
|
||||
}
|
||||
if v, ok := configMap["issuer"].(string); ok {
|
||||
cfg.Issuer = v
|
||||
}
|
||||
|
||||
// Infer the original identity provider type from Dex connector type and ID
|
||||
cfg.Type = inferIdentityProviderType(conn.Type, conn.ID, configMap)
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// inferIdentityProviderType determines the original identity provider type
|
||||
// based on the Dex connector type, connector ID, and configuration.
|
||||
func inferIdentityProviderType(dexType, connectorID string, _ map[string]interface{}) string {
|
||||
if dexType != "oidc" {
|
||||
return dexType
|
||||
}
|
||||
return inferOIDCProviderType(connectorID)
|
||||
}
|
||||
|
||||
// inferOIDCProviderType infers the specific OIDC provider from connector ID
|
||||
func inferOIDCProviderType(connectorID string) string {
|
||||
connectorIDLower := strings.ToLower(connectorID)
|
||||
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"} {
|
||||
if strings.Contains(connectorIDLower, provider) {
|
||||
return provider
|
||||
}
|
||||
}
|
||||
return "oidc"
|
||||
}
|
||||
|
||||
// encodeConnectorConfig serializes connector config to JSON bytes.
|
||||
func encodeConnectorConfig(config map[string]interface{}) ([]byte, error) {
|
||||
return json.Marshal(config)
|
||||
}
|
||||
|
||||
// decodeConnectorConfig deserializes connector config from JSON bytes.
|
||||
func decodeConnectorConfig(data []byte, v interface{}) error {
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
// GetRedirectURI returns the default redirect URI for connectors.
|
||||
func (p *Provider) GetRedirectURI() string {
|
||||
if p.config == nil {
|
||||
return ""
|
||||
}
|
||||
issuer := strings.TrimSuffix(p.config.Issuer, "/")
|
||||
if !strings.HasSuffix(issuer, "/oauth2") {
|
||||
issuer += "/oauth2"
|
||||
}
|
||||
return issuer + "/callback"
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetIssuer returns the OIDC issuer URL.
|
||||
|
||||
@@ -82,16 +82,6 @@ read_nb_domain() {
|
||||
return 0
|
||||
}
|
||||
|
||||
get_turn_external_ip() {
|
||||
TURN_EXTERNAL_IP_CONFIG="#external-ip="
|
||||
IP=$(curl -s -4 https://jsonip.com | jq -r '.ip')
|
||||
if [[ "x-$IP" != "x-" ]]; then
|
||||
TURN_EXTERNAL_IP_CONFIG="external-ip=$IP"
|
||||
fi
|
||||
echo "$TURN_EXTERNAL_IP_CONFIG"
|
||||
return 0
|
||||
}
|
||||
|
||||
read_reverse_proxy_type() {
|
||||
echo "" > /dev/stderr
|
||||
echo "Which reverse proxy will you use?" > /dev/stderr
|
||||
@@ -249,14 +239,17 @@ initialize_default_values() {
|
||||
NETBIRD_PORT=80
|
||||
NETBIRD_HTTP_PROTOCOL="http"
|
||||
NETBIRD_RELAY_PROTO="rel"
|
||||
TURN_USER="self"
|
||||
TURN_PASSWORD=$(openssl rand -base64 32 | sed "$SED_STRIP_PADDING")
|
||||
NETBIRD_RELAY_AUTH_SECRET=$(openssl rand -base64 32 | sed "$SED_STRIP_PADDING")
|
||||
# Note: DataStoreEncryptionKey must keep base64 padding (=) for Go's base64.StdEncoding
|
||||
DATASTORE_ENCRYPTION_KEY=$(openssl rand -base64 32)
|
||||
TURN_MIN_PORT=49152
|
||||
TURN_MAX_PORT=65535
|
||||
TURN_EXTERNAL_IP_CONFIG=$(get_turn_external_ip)
|
||||
NETBIRD_STUN_PORT=3478
|
||||
|
||||
# Docker images
|
||||
CADDY_IMAGE="caddy"
|
||||
DASHBOARD_IMAGE="netbirdio/dashboard:latest"
|
||||
SIGNAL_IMAGE="netbirdio/signal:latest"
|
||||
RELAY_IMAGE="netbirdio/relay:latest"
|
||||
MANAGEMENT_IMAGE="netbirdio/management:latest"
|
||||
|
||||
# Reverse proxy configuration
|
||||
REVERSE_PROXY_TYPE="0"
|
||||
@@ -320,7 +313,7 @@ check_existing_installation() {
|
||||
echo "Generated files already exist, if you want to reinitialize the environment, please remove them first."
|
||||
echo "You can use the following commands:"
|
||||
echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes"
|
||||
echo " rm -f docker-compose.yml Caddyfile dashboard.env turnserver.conf management.json relay.env nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt"
|
||||
echo " rm -f docker-compose.yml Caddyfile dashboard.env management.json relay.env nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt"
|
||||
echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard."
|
||||
exit 1
|
||||
fi
|
||||
@@ -363,7 +356,6 @@ generate_configuration_files() {
|
||||
# Common files for all configurations
|
||||
render_dashboard_env > dashboard.env
|
||||
render_management_json > management.json
|
||||
render_turn_server_conf > turnserver.conf
|
||||
render_relay_env > relay.env
|
||||
return 0
|
||||
}
|
||||
@@ -487,34 +479,13 @@ EOF
|
||||
return 0
|
||||
}
|
||||
|
||||
render_turn_server_conf() {
|
||||
cat <<EOF
|
||||
listening-port=3478
|
||||
$TURN_EXTERNAL_IP_CONFIG
|
||||
tls-listening-port=5349
|
||||
min-port=$TURN_MIN_PORT
|
||||
max-port=$TURN_MAX_PORT
|
||||
fingerprint
|
||||
lt-cred-mech
|
||||
user=$TURN_USER:$TURN_PASSWORD
|
||||
realm=wiretrustee.com
|
||||
cert=/etc/coturn/certs/cert.pem
|
||||
pkey=/etc/coturn/private/privkey.pem
|
||||
log-file=stdout
|
||||
no-software-attribute
|
||||
pidfile="/var/tmp/turnserver.pid"
|
||||
no-cli
|
||||
EOF
|
||||
return 0
|
||||
}
|
||||
|
||||
render_management_json() {
|
||||
cat <<EOF
|
||||
{
|
||||
"Stuns": [
|
||||
{
|
||||
"Proto": "udp",
|
||||
"URI": "stun:$NETBIRD_DOMAIN:3478"
|
||||
"URI": "stun:$NETBIRD_DOMAIN:$NETBIRD_STUN_PORT"
|
||||
}
|
||||
],
|
||||
"Relay": {
|
||||
@@ -569,6 +540,9 @@ NB_LOG_LEVEL=info
|
||||
NB_LISTEN_ADDRESS=:80
|
||||
NB_EXPOSED_ADDRESS=$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT
|
||||
NB_AUTH_SECRET=$NETBIRD_RELAY_AUTH_SECRET
|
||||
NB_ENABLE_STUN=true
|
||||
NB_STUN_LOG_LEVEL=info
|
||||
NB_STUN_PORTS=$NETBIRD_STUN_PORT
|
||||
EOF
|
||||
return 0
|
||||
}
|
||||
@@ -578,7 +552,7 @@ render_docker_compose() {
|
||||
services:
|
||||
# Caddy reverse proxy
|
||||
caddy:
|
||||
image: caddy
|
||||
image: $CADDY_IMAGE
|
||||
container_name: netbird-caddy
|
||||
restart: unless-stopped
|
||||
networks: [netbird]
|
||||
@@ -597,7 +571,7 @@ services:
|
||||
|
||||
# UI dashboard
|
||||
dashboard:
|
||||
image: netbirdio/dashboard:latest
|
||||
image: $DASHBOARD_IMAGE
|
||||
container_name: netbird-dashboard
|
||||
restart: unless-stopped
|
||||
networks: [netbird]
|
||||
@@ -611,7 +585,7 @@ services:
|
||||
|
||||
# Signal
|
||||
signal:
|
||||
image: netbirdio/signal:latest
|
||||
image: $SIGNAL_IMAGE
|
||||
container_name: netbird-signal
|
||||
restart: unless-stopped
|
||||
networks: [netbird]
|
||||
@@ -621,12 +595,14 @@ services:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
# Relay
|
||||
# Relay (includes embedded STUN server)
|
||||
relay:
|
||||
image: netbirdio/relay:latest
|
||||
image: $RELAY_IMAGE
|
||||
container_name: netbird-relay
|
||||
restart: unless-stopped
|
||||
networks: [netbird]
|
||||
ports:
|
||||
- '$NETBIRD_STUN_PORT:$NETBIRD_STUN_PORT/udp'
|
||||
env_file:
|
||||
- ./relay.env
|
||||
logging:
|
||||
@@ -637,7 +613,7 @@ services:
|
||||
|
||||
# Management (includes embedded IdP)
|
||||
management:
|
||||
image: netbirdio/management:latest
|
||||
image: $MANAGEMENT_IMAGE
|
||||
container_name: netbird-management
|
||||
restart: unless-stopped
|
||||
networks: [netbird]
|
||||
@@ -659,22 +635,6 @@ services:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
# Coturn, AKA TURN server
|
||||
coturn:
|
||||
image: coturn/coturn
|
||||
container_name: netbird-coturn
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- ./turnserver.conf:/etc/turnserver.conf:ro
|
||||
network_mode: host
|
||||
command:
|
||||
- -c /etc/turnserver.conf
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
volumes:
|
||||
netbird_caddy_data:
|
||||
netbird_management:
|
||||
@@ -702,7 +662,7 @@ render_docker_compose_traefik() {
|
||||
services:
|
||||
# UI dashboard
|
||||
dashboard:
|
||||
image: netbirdio/dashboard:latest
|
||||
image: $DASHBOARD_IMAGE
|
||||
container_name: netbird-dashboard
|
||||
restart: unless-stopped
|
||||
networks: [$network_name]
|
||||
@@ -724,7 +684,7 @@ $(if [[ -n "$tls_labels" ]]; then echo " - traefik.http.routers.netbird-das
|
||||
|
||||
# Signal
|
||||
signal:
|
||||
image: netbirdio/signal:latest
|
||||
image: $SIGNAL_IMAGE
|
||||
container_name: netbird-signal
|
||||
restart: unless-stopped
|
||||
networks: [$network_name]
|
||||
@@ -751,12 +711,14 @@ $(if [[ -n "$tls_labels" ]]; then echo " - traefik.http.routers.netbird-sig
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
# Relay
|
||||
# Relay (includes embedded STUN server)
|
||||
relay:
|
||||
image: netbirdio/relay:latest
|
||||
image: $RELAY_IMAGE
|
||||
container_name: netbird-relay
|
||||
restart: unless-stopped
|
||||
networks: [$network_name]
|
||||
ports:
|
||||
- '$NETBIRD_STUN_PORT:$NETBIRD_STUN_PORT/udp'
|
||||
env_file:
|
||||
- ./relay.env
|
||||
labels:
|
||||
@@ -774,7 +736,7 @@ $(if [[ -n "$tls_labels" ]]; then echo " - traefik.http.routers.netbird-rel
|
||||
|
||||
# Management (includes embedded IdP)
|
||||
management:
|
||||
image: netbirdio/management:latest
|
||||
image: $MANAGEMENT_IMAGE
|
||||
container_name: netbird-management
|
||||
restart: unless-stopped
|
||||
networks: [$network_name]
|
||||
@@ -827,24 +789,6 @@ $(if [[ -n "$tls_labels" ]]; then echo " - traefik.http.routers.netbird-oau
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
# Coturn, AKA TURN server
|
||||
coturn:
|
||||
image: coturn/coturn
|
||||
container_name: netbird-coturn
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- ./turnserver.conf:/etc/turnserver.conf:ro
|
||||
network_mode: host
|
||||
labels:
|
||||
- traefik.enable=false
|
||||
command:
|
||||
- -c /etc/turnserver.conf
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
volumes:
|
||||
netbird_management:
|
||||
|
||||
@@ -874,7 +818,7 @@ render_docker_compose_exposed_ports() {
|
||||
services:
|
||||
# UI dashboard
|
||||
dashboard:
|
||||
image: netbirdio/dashboard:latest
|
||||
image: $DASHBOARD_IMAGE
|
||||
container_name: netbird-dashboard
|
||||
restart: unless-stopped
|
||||
networks: ${networks}
|
||||
@@ -890,7 +834,7 @@ services:
|
||||
|
||||
# Signal
|
||||
signal:
|
||||
image: netbirdio/signal:latest
|
||||
image: $SIGNAL_IMAGE
|
||||
container_name: netbird-signal
|
||||
restart: unless-stopped
|
||||
networks: ${networks}
|
||||
@@ -903,14 +847,15 @@ services:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
# Relay
|
||||
# Relay (includes embedded STUN server)
|
||||
relay:
|
||||
image: netbirdio/relay:latest
|
||||
image: $RELAY_IMAGE
|
||||
container_name: netbird-relay
|
||||
restart: unless-stopped
|
||||
networks: ${networks}
|
||||
ports:
|
||||
- '${bind_addr}:${RELAY_HOST_PORT}:80'
|
||||
- '$NETBIRD_STUN_PORT:$NETBIRD_STUN_PORT/udp'
|
||||
env_file:
|
||||
- ./relay.env
|
||||
logging:
|
||||
@@ -921,7 +866,7 @@ services:
|
||||
|
||||
# Management (includes embedded IdP)
|
||||
management:
|
||||
image: netbirdio/management:latest
|
||||
image: $MANAGEMENT_IMAGE
|
||||
container_name: netbird-management
|
||||
restart: unless-stopped
|
||||
networks: ${networks}
|
||||
@@ -945,22 +890,6 @@ services:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
# Coturn, AKA TURN server
|
||||
coturn:
|
||||
image: coturn/coturn
|
||||
container_name: netbird-coturn
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- ./turnserver.conf:/etc/turnserver.conf:ro
|
||||
network_mode: host
|
||||
command:
|
||||
- -c /etc/turnserver.conf
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
volumes:
|
||||
netbird_management:
|
||||
|
||||
|
||||
@@ -856,3 +856,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
||||
func (c *Controller) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) {
|
||||
c.peersUpdateManager.CloseChannels(ctx, peerIDs)
|
||||
}
|
||||
|
||||
func (c *Controller) TrackEphemeralPeer(ctx context.Context, peer *nbpeer.Peer) {
|
||||
c.EphemeralPeersManager.OnPeerDisconnected(ctx, peer)
|
||||
}
|
||||
|
||||
@@ -36,4 +36,6 @@ type Controller interface {
|
||||
DisconnectPeers(ctx context.Context, accountId string, peerIDs []string)
|
||||
OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error)
|
||||
OnPeerDisconnected(ctx context.Context, accountID string, peerID string)
|
||||
|
||||
TrackEphemeralPeer(ctx context.Context, peer *nbpeer.Peer)
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: ./interface.go
|
||||
// Source: management/internals/controllers/network_map/interface.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -package network_map -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
|
||||
// mockgen -package network_map -destination=management/internals/controllers/network_map/interface_mock.go -source=management/internals/controllers/network_map/interface.go -build_flags=-mod=mod
|
||||
//
|
||||
|
||||
// Package network_map is a generated GoMock package.
|
||||
@@ -211,6 +211,18 @@ func (mr *MockControllerMockRecorder) StartWarmup(arg0 any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartWarmup", reflect.TypeOf((*MockController)(nil).StartWarmup), arg0)
|
||||
}
|
||||
|
||||
// TrackEphemeralPeer mocks base method.
|
||||
func (m *MockController) TrackEphemeralPeer(ctx context.Context, arg1 *peer.Peer) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "TrackEphemeralPeer", ctx, arg1)
|
||||
}
|
||||
|
||||
// TrackEphemeralPeer indicates an expected call of TrackEphemeralPeer.
|
||||
func (mr *MockControllerMockRecorder) TrackEphemeralPeer(ctx, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TrackEphemeralPeer", reflect.TypeOf((*MockController)(nil).TrackEphemeralPeer), ctx, arg1)
|
||||
}
|
||||
|
||||
// UpdateAccountPeer mocks base method.
|
||||
func (m *MockController) UpdateAccountPeer(ctx context.Context, accountId, peerId string) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -31,6 +31,7 @@ type Manager interface {
|
||||
SetNetworkMapController(networkMapController network_map.Controller)
|
||||
SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator)
|
||||
SetAccountManager(accountManager account.Manager)
|
||||
GetPeerID(ctx context.Context, peerKey string) (string, error)
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
@@ -167,3 +168,7 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetPeerID(ctx context.Context, peerKey string) (string, error) {
|
||||
return m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
|
||||
}
|
||||
|
||||
@@ -97,6 +97,21 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID)
|
||||
}
|
||||
|
||||
// GetPeerID mocks base method.
|
||||
func (m *MockManager) GetPeerID(ctx context.Context, peerKey string) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPeerID", ctx, peerKey)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPeerID indicates an expected call of GetPeerID.
|
||||
func (mr *MockManagerMockRecorder) GetPeerID(ctx, peerKey interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerID", reflect.TypeOf((*MockManager)(nil).GetPeerID), ctx, peerKey)
|
||||
}
|
||||
|
||||
// GetPeersByGroupIDs mocks base method.
|
||||
func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -144,7 +144,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
}
|
||||
|
||||
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
||||
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider())
|
||||
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create management server: %v", err)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
)
|
||||
|
||||
func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
|
||||
@@ -24,6 +26,12 @@ func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) JobManager() *job.Manager {
|
||||
return Create(s, func() *job.Manager {
|
||||
return job.NewJobManager(s.Metrics(), s.Store(), s.PeersManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
|
||||
return Create(s, func() integrated_validator.IntegratedValidator {
|
||||
integratedPeerValidator, err := integrations.NewIntegratedValidator(
|
||||
|
||||
@@ -87,7 +87,7 @@ func (s *BaseServer) PeersManager() peers.Manager {
|
||||
|
||||
func (s *BaseServer) AccountManager() account.Manager {
|
||||
return Create(s, func() account.Manager {
|
||||
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
|
||||
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create account manager: %v", err)
|
||||
}
|
||||
|
||||
@@ -195,6 +195,7 @@ func TestBuildJWTConfig_Audiences(t *testing.T) {
|
||||
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, tc.expectedAudiences, result.Audiences, "audiences should match expected")
|
||||
//nolint:staticcheck // SA1019: Testing backwards compatibility - Audience field must still be populated
|
||||
assert.Equal(t, tc.expectedAudience, result.Audience, "audience should match expected")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
@@ -26,6 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -57,6 +59,7 @@ type Server struct {
|
||||
accountManager account.Manager
|
||||
settingsManager settings.Manager
|
||||
proto.UnimplementedManagementServiceServer
|
||||
jobManager *job.Manager
|
||||
config *nbconfig.Config
|
||||
secretsManager SecretsManager
|
||||
appMetrics telemetry.AppMetrics
|
||||
@@ -82,6 +85,7 @@ func NewServer(
|
||||
config *nbconfig.Config,
|
||||
accountManager account.Manager,
|
||||
settingsManager settings.Manager,
|
||||
jobManager *job.Manager,
|
||||
secretsManager SecretsManager,
|
||||
appMetrics telemetry.AppMetrics,
|
||||
authManager auth.Manager,
|
||||
@@ -114,6 +118,7 @@ func NewServer(
|
||||
}
|
||||
|
||||
return &Server{
|
||||
jobManager: jobManager,
|
||||
accountManager: accountManager,
|
||||
settingsManager: settingsManager,
|
||||
config: config,
|
||||
@@ -169,6 +174,40 @@ func getRealIP(ctx context.Context) net.IP {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Job(srv proto.ManagementService_JobServer) error {
|
||||
reqStart := time.Now()
|
||||
ctx := srv.Context()
|
||||
|
||||
peerKey, err := s.handleHandshake(ctx, srv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
||||
if err != nil {
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN")
|
||||
log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String())
|
||||
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
|
||||
return status.Errorf(codes.PermissionDenied, "peer is not registered")
|
||||
}
|
||||
return err
|
||||
}
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||
peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerKey.String())
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Unauthenticated, "peer is not registered")
|
||||
}
|
||||
|
||||
s.startResponseReceiver(ctx, srv)
|
||||
|
||||
updates := s.jobManager.CreateJobChannel(ctx, accountID, peer.ID)
|
||||
log.WithContext(ctx).Debugf("Job: took %v", time.Since(reqStart))
|
||||
|
||||
return s.sendJobsLoop(ctx, accountID, peerKey, peer, updates, srv)
|
||||
}
|
||||
|
||||
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
||||
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
||||
func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||
@@ -289,6 +328,70 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
||||
}
|
||||
|
||||
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
|
||||
hello, err := srv.Recv()
|
||||
if err != nil {
|
||||
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "missing hello: %v", err)
|
||||
}
|
||||
|
||||
jobReq := &proto.JobRequest{}
|
||||
peerKey, err := s.parseRequest(ctx, hello, jobReq)
|
||||
if err != nil {
|
||||
return wgtypes.Key{}, err
|
||||
}
|
||||
|
||||
return peerKey, nil
|
||||
}
|
||||
|
||||
func (s *Server) startResponseReceiver(ctx context.Context, srv proto.ManagementService_JobServer) {
|
||||
go func() {
|
||||
for {
|
||||
msg, err := srv.Recv()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
|
||||
return
|
||||
}
|
||||
log.WithContext(ctx).Warnf("recv job response error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
jobResp := &proto.JobResponse{}
|
||||
if _, err := s.parseRequest(ctx, msg, jobResp); err != nil {
|
||||
log.WithContext(ctx).Warnf("invalid job response: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.jobManager.HandleResponse(ctx, jobResp, msg.WgPubKey); err != nil {
|
||||
log.WithContext(ctx).Errorf("handle job response failed: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates *job.Channel, srv proto.ManagementService_JobServer) error {
|
||||
// todo figure out better error handling strategy
|
||||
defer s.jobManager.CloseChannel(ctx, accountID, peer.ID)
|
||||
|
||||
for {
|
||||
event, err := updates.Event(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, job.ErrJobChannelClosed) {
|
||||
log.WithContext(ctx).Debugf("jobs channel for peer %s was closed", peerKey.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
// happens when connection drops, e.g. client disconnects
|
||||
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
if err := s.sendJob(ctx, peerKey, event, srv); err != nil {
|
||||
log.WithContext(ctx).Warnf("send job failed: %v", err)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
|
||||
@@ -306,7 +409,6 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
||||
return nil
|
||||
}
|
||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
|
||||
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
|
||||
return err
|
||||
@@ -336,7 +438,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return status.Errorf(codes.Internal, "failed processing update message")
|
||||
}
|
||||
err = srv.SendMsg(&proto.EncryptedMessage{
|
||||
err = srv.Send(&proto.EncryptedMessage{
|
||||
WgPubKey: key.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
})
|
||||
@@ -348,6 +450,31 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendJob encrypts the update message using the peer key and the server's wireguard key,
|
||||
// then sends the encrypted message to the connected peer via the sync server.
|
||||
func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Event, srv proto.ManagementService_JobServer) error {
|
||||
wgKey, err := s.secretsManager.GetWGKey()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get wg key for peer %s: %v", peerKey.String(), err)
|
||||
return status.Errorf(codes.Internal, "failed processing job message")
|
||||
}
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, wgKey, job.Request)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to encrypt job for peer %s: %v", peerKey.String(), err)
|
||||
return status.Errorf(codes.Internal, "failed processing job message")
|
||||
}
|
||||
err = srv.Send(&proto.EncryptedMessage{
|
||||
WgPubKey: wgKey.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
})
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed sending job message")
|
||||
}
|
||||
log.WithContext(ctx).Debugf("sent a job to peer: %s", peerKey.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
||||
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
||||
defer unlock()
|
||||
@@ -690,8 +817,8 @@ func (s *Server) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty,
|
||||
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
|
||||
func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer, dnsFwdPort int64) error {
|
||||
var err error
|
||||
|
||||
var turnToken *Token
|
||||
|
||||
if s.config.TURNConfig != nil && s.config.TURNConfig.TimeBasedCredentials {
|
||||
turnToken, err = s.secretsManager.GenerateTurnToken()
|
||||
if err != nil {
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
cacheStore "github.com/eko/gocache/lib/v4/store"
|
||||
@@ -70,6 +71,7 @@ type DefaultAccountManager struct {
|
||||
// cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded
|
||||
cacheLoading map[string]chan struct{}
|
||||
networkMapController network_map.Controller
|
||||
jobManager *job.Manager
|
||||
idpManager idp.Manager
|
||||
cacheManager *nbcache.AccountUserDataCache
|
||||
externalCacheManager nbcache.UserDataCache
|
||||
@@ -178,6 +180,7 @@ func BuildManager(
|
||||
config *nbconfig.Config,
|
||||
store store.Store,
|
||||
networkMapController network_map.Controller,
|
||||
jobManager *job.Manager,
|
||||
idpManager idp.Manager,
|
||||
singleAccountModeDomain string,
|
||||
eventStore activity.Store,
|
||||
@@ -200,6 +203,7 @@ func BuildManager(
|
||||
config: config,
|
||||
geo: geo,
|
||||
networkMapController: networkMapController,
|
||||
jobManager: jobManager,
|
||||
idpManager: idpManager,
|
||||
ctx: context.Background(),
|
||||
cacheMux: sync.Mutex{},
|
||||
|
||||
@@ -32,6 +32,7 @@ type Manager interface {
|
||||
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error)
|
||||
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
|
||||
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
|
||||
UpdateUserPassword(ctx context.Context, accountID, currentUserID, targetUserID string, oldPassword, newPassword string) error
|
||||
InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
||||
ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error)
|
||||
RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error
|
||||
@@ -129,4 +130,7 @@ type Manager interface {
|
||||
CreateIdentityProvider(ctx context.Context, accountID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error)
|
||||
UpdateIdentityProvider(ctx context.Context, accountID, idpID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error)
|
||||
DeleteIdentityProvider(ctx context.Context, accountID, idpID, userID string) error
|
||||
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
|
||||
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
|
||||
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
|
||||
}
|
||||
|
||||
@@ -35,6 +35,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -3023,13 +3024,14 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
AnyTimes()
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
|
||||
manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
@@ -195,6 +195,10 @@ const (
|
||||
DNSRecordUpdated Activity = 100
|
||||
DNSRecordDeleted Activity = 101
|
||||
|
||||
JobCreatedByUser Activity = 102
|
||||
|
||||
UserPasswordChanged Activity = 103
|
||||
|
||||
AccountDeleted Activity = 99999
|
||||
)
|
||||
|
||||
@@ -319,6 +323,10 @@ var activityMap = map[Activity]Code{
|
||||
DNSRecordCreated: {"DNS zone record created", "dns.zone.record.create"},
|
||||
DNSRecordUpdated: {"DNS zone record updated", "dns.zone.record.update"},
|
||||
DNSRecordDeleted: {"DNS zone record deleted", "dns.zone.record.delete"},
|
||||
|
||||
JobCreatedByUser: {"Create Job for peer", "peer.job.create"},
|
||||
|
||||
UserPasswordChanged: {"User password changed", "user.password.change"},
|
||||
}
|
||||
|
||||
// StringCode returns a string code of the activity
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
)
|
||||
|
||||
var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
|
||||
|
||||
type FieldEncrypt struct {
|
||||
block cipher.Block
|
||||
gcm cipher.AEAD
|
||||
}
|
||||
|
||||
func GenerateKey() (string, error) {
|
||||
key := make([]byte, 32)
|
||||
_, err := rand.Read(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
readableKey := base64.StdEncoding.EncodeToString(key)
|
||||
return readableKey, nil
|
||||
}
|
||||
|
||||
func NewFieldEncrypt(key string) (*FieldEncrypt, error) {
|
||||
binKey, err := base64.StdEncoding.DecodeString(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(binKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ec := &FieldEncrypt{
|
||||
block: block,
|
||||
gcm: gcm,
|
||||
}
|
||||
|
||||
return ec, nil
|
||||
}
|
||||
|
||||
func (ec *FieldEncrypt) LegacyEncrypt(payload string) string {
|
||||
plainText := pkcs5Padding([]byte(payload))
|
||||
cipherText := make([]byte, len(plainText))
|
||||
cbc := cipher.NewCBCEncrypter(ec.block, iv)
|
||||
cbc.CryptBlocks(cipherText, plainText)
|
||||
return base64.StdEncoding.EncodeToString(cipherText)
|
||||
}
|
||||
|
||||
// Encrypt encrypts plaintext using AES-GCM
|
||||
func (ec *FieldEncrypt) Encrypt(payload string) (string, error) {
|
||||
plaintext := []byte(payload)
|
||||
nonceSize := ec.gcm.NonceSize()
|
||||
|
||||
nonce := make([]byte, nonceSize, len(plaintext)+nonceSize+ec.gcm.Overhead())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ciphertext := ec.gcm.Seal(nonce, nonce, plaintext, nil)
|
||||
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
func (ec *FieldEncrypt) LegacyDecrypt(data string) (string, error) {
|
||||
cipherText, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
cbc := cipher.NewCBCDecrypter(ec.block, iv)
|
||||
cbc.CryptBlocks(cipherText, cipherText)
|
||||
payload, err := pkcs5UnPadding(cipherText)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(payload), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts ciphertext using AES-GCM
|
||||
func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
|
||||
cipherText, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonceSize := ec.gcm.NonceSize()
|
||||
if len(cipherText) < nonceSize {
|
||||
return "", errors.New("cipher text too short")
|
||||
}
|
||||
|
||||
nonce, cipherText := cipherText[:nonceSize], cipherText[nonceSize:]
|
||||
plainText, err := ec.gcm.Open(nil, nonce, cipherText, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(plainText), nil
|
||||
}
|
||||
|
||||
func pkcs5Padding(ciphertext []byte) []byte {
|
||||
padding := aes.BlockSize - len(ciphertext)%aes.BlockSize
|
||||
padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(ciphertext, padText...)
|
||||
}
|
||||
func pkcs5UnPadding(src []byte) ([]byte, error) {
|
||||
srcLen := len(src)
|
||||
if srcLen == 0 {
|
||||
return nil, errors.New("input data is empty")
|
||||
}
|
||||
|
||||
paddingLen := int(src[srcLen-1])
|
||||
if paddingLen == 0 || paddingLen > aes.BlockSize || paddingLen > srcLen {
|
||||
return nil, errors.New("invalid padding size")
|
||||
}
|
||||
|
||||
// Verify that all padding bytes are the same
|
||||
for i := 0; i < paddingLen; i++ {
|
||||
if src[srcLen-1-i] != byte(paddingLen) {
|
||||
return nil, errors.New("invalid padding")
|
||||
}
|
||||
}
|
||||
|
||||
return src[:srcLen-paddingLen], nil
|
||||
}
|
||||
@@ -1,310 +0,0 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenerateKey(t *testing.T) {
|
||||
testData := "exampl@netbird.io"
|
||||
key, err := GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate key: %s", err)
|
||||
}
|
||||
ee, err := NewFieldEncrypt(key)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to init email encryption: %s", err)
|
||||
}
|
||||
|
||||
encrypted, err := ee.Encrypt(testData)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to encrypt data: %s", err)
|
||||
}
|
||||
|
||||
if encrypted == "" {
|
||||
t.Fatalf("invalid encrypted text")
|
||||
}
|
||||
|
||||
decrypted, err := ee.Decrypt(encrypted)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to decrypt data: %s", err)
|
||||
}
|
||||
|
||||
if decrypted != testData {
|
||||
t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateKeyLegacy(t *testing.T) {
|
||||
testData := "exampl@netbird.io"
|
||||
key, err := GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate key: %s", err)
|
||||
}
|
||||
ee, err := NewFieldEncrypt(key)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to init email encryption: %s", err)
|
||||
}
|
||||
|
||||
encrypted := ee.LegacyEncrypt(testData)
|
||||
if encrypted == "" {
|
||||
t.Fatalf("invalid encrypted text")
|
||||
}
|
||||
|
||||
decrypted, err := ee.LegacyDecrypt(encrypted)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to decrypt data: %s", err)
|
||||
}
|
||||
|
||||
if decrypted != testData {
|
||||
t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorruptKey(t *testing.T) {
|
||||
testData := "exampl@netbird.io"
|
||||
key, err := GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate key: %s", err)
|
||||
}
|
||||
ee, err := NewFieldEncrypt(key)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to init email encryption: %s", err)
|
||||
}
|
||||
|
||||
encrypted, err := ee.Encrypt(testData)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to encrypt data: %s", err)
|
||||
}
|
||||
|
||||
if encrypted == "" {
|
||||
t.Fatalf("invalid encrypted text")
|
||||
}
|
||||
|
||||
newKey, err := GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate key: %s", err)
|
||||
}
|
||||
|
||||
ee, err = NewFieldEncrypt(newKey)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to init email encryption: %s", err)
|
||||
}
|
||||
|
||||
res, _ := ee.Decrypt(encrypted)
|
||||
if res == testData {
|
||||
t.Fatalf("incorrect decryption, the result is: %s", res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
// Generate a key for encryption/decryption
|
||||
key, err := GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate key: %v", err)
|
||||
}
|
||||
|
||||
// Initialize the FieldEncrypt with the generated key
|
||||
ec, err := NewFieldEncrypt(key)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create FieldEncrypt: %v", err)
|
||||
}
|
||||
|
||||
// Test cases
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{
|
||||
name: "Empty String",
|
||||
input: "",
|
||||
},
|
||||
{
|
||||
name: "Short String",
|
||||
input: "Hello",
|
||||
},
|
||||
{
|
||||
name: "String with Spaces",
|
||||
input: "Hello, World!",
|
||||
},
|
||||
{
|
||||
name: "Long String",
|
||||
input: "The quick brown fox jumps over the lazy dog.",
|
||||
},
|
||||
{
|
||||
name: "Unicode Characters",
|
||||
input: "こんにちは世界",
|
||||
},
|
||||
{
|
||||
name: "Special Characters",
|
||||
input: "!@#$%^&*()_+-=[]{}|;':\",./<>?",
|
||||
},
|
||||
{
|
||||
name: "Numeric String",
|
||||
input: "1234567890",
|
||||
},
|
||||
{
|
||||
name: "Repeated Characters",
|
||||
input: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
},
|
||||
{
|
||||
name: "Multi-block String",
|
||||
input: "This is a longer string that will span multiple blocks in the encryption algorithm.",
|
||||
},
|
||||
{
|
||||
name: "Non-ASCII and ASCII Mix",
|
||||
input: "Hello 世界 123",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name+" - Legacy", func(t *testing.T) {
|
||||
// Legacy Encryption
|
||||
encryptedLegacy := ec.LegacyEncrypt(tc.input)
|
||||
if encryptedLegacy == "" {
|
||||
t.Errorf("LegacyEncrypt returned empty string for input '%s'", tc.input)
|
||||
}
|
||||
|
||||
// Legacy Decryption
|
||||
decryptedLegacy, err := ec.LegacyDecrypt(encryptedLegacy)
|
||||
if err != nil {
|
||||
t.Errorf("LegacyDecrypt failed for input '%s': %v", tc.input, err)
|
||||
}
|
||||
|
||||
// Verify that the decrypted value matches the original input
|
||||
if decryptedLegacy != tc.input {
|
||||
t.Errorf("LegacyDecrypt output '%s' does not match original input '%s'", decryptedLegacy, tc.input)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run(tc.name+" - New", func(t *testing.T) {
|
||||
// New Encryption
|
||||
encryptedNew, err := ec.Encrypt(tc.input)
|
||||
if err != nil {
|
||||
t.Errorf("Encrypt failed for input '%s': %v", tc.input, err)
|
||||
}
|
||||
if encryptedNew == "" {
|
||||
t.Errorf("Encrypt returned empty string for input '%s'", tc.input)
|
||||
}
|
||||
|
||||
// New Decryption
|
||||
decryptedNew, err := ec.Decrypt(encryptedNew)
|
||||
if err != nil {
|
||||
t.Errorf("Decrypt failed for input '%s': %v", tc.input, err)
|
||||
}
|
||||
|
||||
// Verify that the decrypted value matches the original input
|
||||
if decryptedNew != tc.input {
|
||||
t.Errorf("Decrypt output '%s' does not match original input '%s'", decryptedNew, tc.input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPKCS5UnPadding(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected []byte
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid Padding",
|
||||
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{4}, 4)...),
|
||||
expected: []byte("Hello, World!"),
|
||||
},
|
||||
{
|
||||
name: "Empty Input",
|
||||
input: []byte{},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Padding Length Zero",
|
||||
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{0}, 4)...),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Padding Length Exceeds Block Size",
|
||||
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{17}, 17)...),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Padding Length Exceeds Input Length",
|
||||
input: []byte{5, 5, 5},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Padding Bytes",
|
||||
input: append([]byte("Hello, World!"), []byte{2, 3, 4, 5}...),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Valid Single Byte Padding",
|
||||
input: append([]byte("Hello, World!"), byte(1)),
|
||||
expected: []byte("Hello, World!"),
|
||||
},
|
||||
{
|
||||
name: "Invalid Mixed Padding Bytes",
|
||||
input: append([]byte("Hello, World!"), []byte{3, 3, 2}...),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Valid Full Block Padding",
|
||||
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{16}, 16)...),
|
||||
expected: []byte("Hello, World!"),
|
||||
},
|
||||
{
|
||||
name: "Non-Padding Byte at End",
|
||||
input: append([]byte("Hello, World!"), []byte{4, 4, 4, 5}...),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Valid Padding with Different Text Length",
|
||||
input: append([]byte("Test"), bytes.Repeat([]byte{12}, 12)...),
|
||||
expected: []byte("Test"),
|
||||
},
|
||||
{
|
||||
name: "Padding Length Equal to Input Length",
|
||||
input: bytes.Repeat([]byte{8}, 8),
|
||||
expected: []byte{},
|
||||
},
|
||||
{
|
||||
name: "Invalid Padding Length Zero (Again)",
|
||||
input: append([]byte("Test"), byte(0)),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Padding Length Greater Than Input",
|
||||
input: []byte{10},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Input Length Not Multiple of Block Size",
|
||||
input: append([]byte("Invalid Length"), byte(1)),
|
||||
expected: []byte("Invalid Length"),
|
||||
},
|
||||
{
|
||||
name: "Valid Padding with Non-ASCII Characters",
|
||||
input: append([]byte("こんにちは"), bytes.Repeat([]byte{2}, 2)...),
|
||||
expected: []byte("こんにちは"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := pkcs5UnPadding(tt.input)
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error but got nil")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Did not expect error but got: %v", err)
|
||||
}
|
||||
if !bytes.Equal(result, tt.expected) {
|
||||
t.Errorf("Expected output %v, got %v", tt.expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -10,9 +10,10 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/migration"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
func migrate(ctx context.Context, crypt *FieldEncrypt, db *gorm.DB) error {
|
||||
func migrate(ctx context.Context, crypt *crypt.FieldEncrypt, db *gorm.DB) error {
|
||||
migrations := getMigrations(ctx, crypt)
|
||||
|
||||
for _, m := range migrations {
|
||||
@@ -26,7 +27,7 @@ func migrate(ctx context.Context, crypt *FieldEncrypt, db *gorm.DB) error {
|
||||
|
||||
type migrationFunc func(*gorm.DB) error
|
||||
|
||||
func getMigrations(ctx context.Context, crypt *FieldEncrypt) []migrationFunc {
|
||||
func getMigrations(ctx context.Context, crypt *crypt.FieldEncrypt) []migrationFunc {
|
||||
return []migrationFunc{
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNewField[activity.DeletedUser](ctx, db, "name", "")
|
||||
@@ -45,7 +46,7 @@ func getMigrations(ctx context.Context, crypt *FieldEncrypt) []migrationFunc {
|
||||
|
||||
// migrateLegacyEncryptedUsersToGCM migrates previously encrypted data using
|
||||
// legacy CBC encryption with a static IV to the new GCM encryption method.
|
||||
func migrateLegacyEncryptedUsersToGCM(ctx context.Context, db *gorm.DB, crypt *FieldEncrypt) error {
|
||||
func migrateLegacyEncryptedUsersToGCM(ctx context.Context, db *gorm.DB, crypt *crypt.FieldEncrypt) error {
|
||||
model := &activity.DeletedUser{}
|
||||
|
||||
if !db.Migrator().HasTable(model) {
|
||||
@@ -80,7 +81,7 @@ func migrateLegacyEncryptedUsersToGCM(ctx context.Context, db *gorm.DB, crypt *F
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateDeletedUserData(transaction *gorm.DB, user activity.DeletedUser, crypt *FieldEncrypt) error {
|
||||
func updateDeletedUserData(transaction *gorm.DB, user activity.DeletedUser, crypt *crypt.FieldEncrypt) error {
|
||||
var err error
|
||||
var decryptedEmail, decryptedName string
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/migration"
|
||||
"github.com/netbirdio/netbird/management/server/testutil"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -40,10 +41,10 @@ func setupDatabase(t *testing.T) *gorm.DB {
|
||||
func TestMigrateLegacyEncryptedUsersToGCM(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
key, err := GenerateKey()
|
||||
key, err := crypt.GenerateKey()
|
||||
require.NoError(t, err, "Failed to generate key")
|
||||
|
||||
crypt, err := NewFieldEncrypt(key)
|
||||
crypt, err := crypt.NewFieldEncrypt(key)
|
||||
require.NoError(t, err, "Failed to initialize FieldEncrypt")
|
||||
|
||||
t.Run("empty table, no migration required", func(t *testing.T) {
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -45,12 +46,12 @@ type eventWithNames struct {
|
||||
// Store is the implementation of the activity.Store interface backed by SQLite
|
||||
type Store struct {
|
||||
db *gorm.DB
|
||||
fieldEncrypt *FieldEncrypt
|
||||
fieldEncrypt *crypt.FieldEncrypt
|
||||
}
|
||||
|
||||
// NewSqlStore creates a new Store with an event table if not exists.
|
||||
func NewSqlStore(ctx context.Context, dataDir string, encryptionKey string) (*Store, error) {
|
||||
crypt, err := NewFieldEncrypt(encryptionKey)
|
||||
fieldEncrypt, err := crypt.NewFieldEncrypt(encryptionKey)
|
||||
if err != nil {
|
||||
|
||||
return nil, err
|
||||
@@ -61,7 +62,7 @@ func NewSqlStore(ctx context.Context, dataDir string, encryptionKey string) (*St
|
||||
return nil, fmt.Errorf("initialize database: %w", err)
|
||||
}
|
||||
|
||||
if err = migrate(ctx, crypt, db); err != nil {
|
||||
if err = migrate(ctx, fieldEncrypt, db); err != nil {
|
||||
return nil, fmt.Errorf("events database migration: %w", err)
|
||||
}
|
||||
|
||||
@@ -72,7 +73,7 @@ func NewSqlStore(ctx context.Context, dataDir string, encryptionKey string) (*St
|
||||
|
||||
return &Store{
|
||||
db: db,
|
||||
fieldEncrypt: crypt,
|
||||
fieldEncrypt: fieldEncrypt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -9,11 +9,12 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
func TestNewSqlStore(t *testing.T) {
|
||||
dataDir := t.TempDir()
|
||||
key, _ := GenerateKey()
|
||||
key, _ := crypt.GenerateKey()
|
||||
store, err := NewSqlStore(context.Background(), dataDir, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -221,13 +222,14 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
// return empty extra settings for expected calls to UpdateAccountPeers
|
||||
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
|
||||
|
||||
return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
}
|
||||
|
||||
func createDNSStore(t *testing.T) (store.Store, error) {
|
||||
|
||||
@@ -36,6 +36,9 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMap
|
||||
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/temporary-access", peersHandler.CreateTemporaryAccess).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/jobs", peersHandler.ListJobs).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/jobs", peersHandler.CreateJob).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/jobs/{jobId}", peersHandler.GetJob).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// NewHandler creates a new peers Handler
|
||||
@@ -46,6 +49,99 @@ func NewHandler(accountManager account.Manager, networkMapController network_map
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
|
||||
req := &api.JobRequest{}
|
||||
if err := json.NewDecoder(r.Body).Decode(req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
job, err := types.NewJob(userAuth.UserId, userAuth.AccountId, peerID, req)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
if err := h.accountManager.CreatePeerJob(ctx, userAuth.AccountId, peerID, userAuth.UserId, job); err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := toSingleJobResponse(job)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(ctx, w, resp)
|
||||
}
|
||||
|
||||
func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
|
||||
jobs, err := h.accountManager.GetAllPeerJobs(ctx, userAuth.AccountId, userAuth.UserId, peerID)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
respBody := make([]*api.JobResponse, 0, len(jobs))
|
||||
for _, job := range jobs {
|
||||
resp, err := toSingleJobResponse(job)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
respBody = append(respBody, resp)
|
||||
}
|
||||
|
||||
util.WriteJSONObject(ctx, w, respBody)
|
||||
}
|
||||
|
||||
func (h *Handler) GetJob(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
jobID := vars["jobId"]
|
||||
|
||||
job, err := h.accountManager.GetPeerJobByID(ctx, userAuth.AccountId, userAuth.UserId, peerID, jobID)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := toSingleJobResponse(job)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(ctx, w, resp)
|
||||
}
|
||||
|
||||
func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) {
|
||||
peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID)
|
||||
if err != nil {
|
||||
@@ -521,6 +617,28 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
|
||||
}
|
||||
}
|
||||
|
||||
func toSingleJobResponse(job *types.Job) (*api.JobResponse, error) {
|
||||
workload, err := job.BuildWorkloadResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var failed *string
|
||||
if job.FailedReason != "" {
|
||||
failed = &job.FailedReason
|
||||
}
|
||||
|
||||
return &api.JobResponse{
|
||||
Id: job.ID,
|
||||
CreatedAt: job.CreatedAt,
|
||||
CompletedAt: job.CompletedAt,
|
||||
TriggeredBy: job.TriggeredBy,
|
||||
Status: api.JobResponseStatus(job.Status),
|
||||
FailedReason: failed,
|
||||
Workload: *workload,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func fqdn(peer *nbpeer.Peer, dnsDomain string) string {
|
||||
fqdn := peer.FQDN(dnsDomain)
|
||||
if fqdn == "" {
|
||||
|
||||
@@ -33,6 +33,7 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/approve", userHandler.approveUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/reject", userHandler.rejectUser).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/password", userHandler.changePassword).Methods("PUT", "OPTIONS")
|
||||
addUsersTokensEndpoint(accountManager, router)
|
||||
}
|
||||
|
||||
@@ -410,3 +411,46 @@ func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
// passwordChangeRequest represents the request body for password change
|
||||
type passwordChangeRequest struct {
|
||||
OldPassword string `json:"old_password"`
|
||||
NewPassword string `json:"new_password"`
|
||||
}
|
||||
|
||||
// changePassword is a PUT request to change user's password.
|
||||
// Only available when embedded IDP is enabled.
|
||||
// Users can only change their own password.
|
||||
func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPut {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteErrorResponse("invalid user ID", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req passwordChangeRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
@@ -856,3 +856,118 @@ func TestRejectUserEndpoint(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangePasswordEndpoint(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
requestBody string
|
||||
targetUserID string
|
||||
currentUserID string
|
||||
mockError error
|
||||
expectMockNotCalled bool
|
||||
}{
|
||||
{
|
||||
name: "successful password change",
|
||||
expectedStatus: http.StatusOK,
|
||||
requestBody: `{"old_password": "OldPass123!", "new_password": "NewPass456!"}`,
|
||||
targetUserID: existingUserID,
|
||||
currentUserID: existingUserID,
|
||||
mockError: nil,
|
||||
},
|
||||
{
|
||||
name: "missing old password",
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
requestBody: `{"new_password": "NewPass456!"}`,
|
||||
targetUserID: existingUserID,
|
||||
currentUserID: existingUserID,
|
||||
mockError: status.Errorf(status.InvalidArgument, "old password is required"),
|
||||
},
|
||||
{
|
||||
name: "missing new password",
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
requestBody: `{"old_password": "OldPass123!"}`,
|
||||
targetUserID: existingUserID,
|
||||
currentUserID: existingUserID,
|
||||
mockError: status.Errorf(status.InvalidArgument, "new password is required"),
|
||||
},
|
||||
{
|
||||
name: "wrong old password",
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
requestBody: `{"old_password": "WrongPass!", "new_password": "NewPass456!"}`,
|
||||
targetUserID: existingUserID,
|
||||
currentUserID: existingUserID,
|
||||
mockError: status.Errorf(status.InvalidArgument, "invalid password"),
|
||||
},
|
||||
{
|
||||
name: "embedded IDP not enabled",
|
||||
expectedStatus: http.StatusPreconditionFailed,
|
||||
requestBody: `{"old_password": "OldPass123!", "new_password": "NewPass456!"}`,
|
||||
targetUserID: existingUserID,
|
||||
currentUserID: existingUserID,
|
||||
mockError: status.Errorf(status.PreconditionFailed, "password change is only available with embedded identity provider"),
|
||||
},
|
||||
{
|
||||
name: "invalid JSON request",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
requestBody: `{invalid json}`,
|
||||
targetUserID: existingUserID,
|
||||
currentUserID: existingUserID,
|
||||
expectMockNotCalled: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mockCalled := false
|
||||
am := &mock_server.MockAccountManager{}
|
||||
am.UpdateUserPasswordFunc = func(ctx context.Context, accountID, currentUserID, targetUserID string, oldPassword, newPassword string) error {
|
||||
mockCalled = true
|
||||
return tc.mockError
|
||||
}
|
||||
|
||||
handler := newHandler(am)
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/users/{userId}/password", handler.changePassword).Methods("PUT")
|
||||
|
||||
reqPath := "/users/" + tc.targetUserID + "/password"
|
||||
req, err := http.NewRequest("PUT", reqPath, bytes.NewBufferString(tc.requestBody))
|
||||
require.NoError(t, err)
|
||||
|
||||
userAuth := auth.UserAuth{
|
||||
AccountId: existingAccountID,
|
||||
UserId: tc.currentUserID,
|
||||
}
|
||||
ctx := nbcontext.SetUserAuthInContext(req.Context(), userAuth)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
if tc.expectMockNotCalled {
|
||||
assert.False(t, mockCalled, "mock should not have been called")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangePasswordEndpoint_WrongMethod(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{}
|
||||
handler := newHandler(am)
|
||||
|
||||
req, err := http.NewRequest("POST", "/users/test-user/password", bytes.NewBufferString(`{}`))
|
||||
require.NoError(t, err)
|
||||
|
||||
userAuth := auth.UserAuth{
|
||||
AccountId: existingAccountID,
|
||||
UserId: existingUserID,
|
||||
}
|
||||
req = nbcontext.SetUserAuthInRequest(req, userAuth)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.changePassword(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusMethodNotAllowed, rr.Code)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
@@ -20,6 +21,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
@@ -72,11 +74,14 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
userManager := users.NewManager(store)
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
ctx := context.Background()
|
||||
requestBuffer := server.NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
|
||||
am, err := server.BuildManager(ctx, nil, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManager), &config.Config{})
|
||||
am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
@@ -94,7 +99,6 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
resourcesManagerMock := resources.NewManagerMock()
|
||||
routersManagerMock := routers.NewManagerMock()
|
||||
groupsManagerMock := groups.NewManagerMock()
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -80,11 +81,12 @@ func createManagerWithEmbeddedIdP(t testing.TB) (*DefaultAccountManager, *update
|
||||
AnyTimes()
|
||||
|
||||
permissionsManager := permissions.NewManager(testStore)
|
||||
peersManager := peers.NewManager(testStore, permissionsManager)
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, testStore)
|
||||
networkMapController := controller.NewController(ctx, testStore, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(testStore, peers.NewManager(testStore, permissionsManager)), &config.Config{})
|
||||
manager, err := BuildManager(ctx, &config.Config{}, testStore, networkMapController, idpManager, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
networkMapController := controller.NewController(ctx, testStore, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(testStore, peersManager), &config.Config{})
|
||||
manager, err := BuildManager(ctx, &config.Config{}, testStore, networkMapController, job.NewJobManager(nil, testStore, peersManager), idpManager, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
@@ -400,7 +400,6 @@ func (m *EmbeddedIdPManager) CreateUserWithPassword(ctx context.Context, email,
|
||||
|
||||
// InviteUserByID resends an invitation to a user.
|
||||
func (m *EmbeddedIdPManager) InviteUserByID(ctx context.Context, userID string) error {
|
||||
// TODO: implement
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
@@ -432,6 +431,33 @@ func (m *EmbeddedIdPManager) DeleteUser(ctx context.Context, userID string) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateUserPassword updates the password for a user in the embedded IdP.
|
||||
// It verifies that the current user is changing their own password and
|
||||
// validates the current password before updating to the new password.
|
||||
func (m *EmbeddedIdPManager) UpdateUserPassword(ctx context.Context, currentUserID, targetUserID string, oldPassword, newPassword string) error {
|
||||
// Verify the user is changing their own password
|
||||
if currentUserID != targetUserID {
|
||||
return fmt.Errorf("users can only change their own password")
|
||||
}
|
||||
|
||||
// Verify the new password is different from the old password
|
||||
if oldPassword == newPassword {
|
||||
return fmt.Errorf("new password must be different from current password")
|
||||
}
|
||||
|
||||
err := m.provider.UpdateUserPassword(ctx, targetUserID, oldPassword, newPassword)
|
||||
if err != nil {
|
||||
if m.appMetrics != nil {
|
||||
m.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("updated password for user %s in embedded IdP", targetUserID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateConnector creates a new identity provider connector in Dex.
|
||||
// Returns the created connector config with the redirect URL populated.
|
||||
func (m *EmbeddedIdPManager) CreateConnector(ctx context.Context, cfg *dex.ConnectorConfig) (*dex.ConnectorConfig, error) {
|
||||
@@ -449,15 +475,8 @@ func (m *EmbeddedIdPManager) ListConnectors(ctx context.Context) ([]*dex.Connect
|
||||
}
|
||||
|
||||
// UpdateConnector updates an existing identity provider connector.
|
||||
// Field preservation for partial updates is handled by Provider.UpdateConnector.
|
||||
func (m *EmbeddedIdPManager) UpdateConnector(ctx context.Context, cfg *dex.ConnectorConfig) error {
|
||||
// Preserve existing secret if not provided in update
|
||||
if cfg.ClientSecret == "" {
|
||||
existing, err := m.provider.GetConnector(ctx, cfg.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing connector: %w", err)
|
||||
}
|
||||
cfg.ClientSecret = existing.ClientSecret
|
||||
}
|
||||
return m.provider.UpdateConnector(ctx, cfg)
|
||||
}
|
||||
|
||||
|
||||
@@ -248,6 +248,71 @@ func TestEmbeddedIdPManager_UserIDFormat_MatchesJWT(t *testing.T) {
|
||||
t.Logf(" Connector: %s", connectorID)
|
||||
}
|
||||
|
||||
func TestEmbeddedIdPManager_UpdateUserPassword(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
config := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: filepath.Join(tmpDir, "dex.db"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := NewEmbeddedIdPManager(ctx, config, nil)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = manager.Stop(ctx) }()
|
||||
|
||||
// Create a user with a known password
|
||||
email := "password-test@example.com"
|
||||
name := "Password Test User"
|
||||
initialPassword := "InitialPass123!"
|
||||
|
||||
userData, err := manager.CreateUserWithPassword(ctx, email, initialPassword, name)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, userData)
|
||||
|
||||
userID := userData.ID
|
||||
|
||||
t.Run("successful password change", func(t *testing.T) {
|
||||
newPassword := "NewSecurePass456!"
|
||||
err := manager.UpdateUserPassword(ctx, userID, userID, initialPassword, newPassword)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the new password works by changing it again
|
||||
anotherPassword := "AnotherPass789!"
|
||||
err = manager.UpdateUserPassword(ctx, userID, userID, newPassword, anotherPassword)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("wrong old password", func(t *testing.T) {
|
||||
err := manager.UpdateUserPassword(ctx, userID, userID, "wrongpassword", "NewPass123!")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "current password is incorrect")
|
||||
})
|
||||
|
||||
t.Run("cannot change other user password", func(t *testing.T) {
|
||||
otherUserID := "other-user-id"
|
||||
err := manager.UpdateUserPassword(ctx, userID, otherUserID, "oldpass", "newpass")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "users can only change their own password")
|
||||
})
|
||||
|
||||
t.Run("same password rejected", func(t *testing.T) {
|
||||
samePassword := "SamePass123!"
|
||||
err := manager.UpdateUserPassword(ctx, userID, userID, samePassword, samePassword)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "new password must be different")
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmbeddedIdPManager_GetLocalKeysLocation(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
|
||||
59
management/server/job/channel.go
Normal file
59
management/server/job/channel.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package job
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// todo consider the channel buffer size when we allow to run multiple jobs
|
||||
const jobChannelBuffer = 1
|
||||
|
||||
var (
|
||||
ErrJobChannelClosed = errors.New("job channel closed")
|
||||
)
|
||||
|
||||
type Channel struct {
|
||||
events chan *Event
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func NewChannel() *Channel {
|
||||
jc := &Channel{
|
||||
events: make(chan *Event, jobChannelBuffer),
|
||||
}
|
||||
|
||||
return jc
|
||||
}
|
||||
|
||||
func (jc *Channel) AddEvent(ctx context.Context, responseWait time.Duration, event *Event) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
// todo: timeout is handled in the wrong place. If the peer does not respond with the job response, the server does not clean it up from the pending jobs and cannot apply a new job
|
||||
case <-time.After(responseWait):
|
||||
return fmt.Errorf("failed to add the event to the channel")
|
||||
case jc.events <- event:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (jc *Channel) Close() {
|
||||
jc.once.Do(func() {
|
||||
close(jc.events)
|
||||
})
|
||||
}
|
||||
|
||||
func (jc *Channel) Event(ctx context.Context) (*Event, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case job, open := <-jc.events:
|
||||
if !open {
|
||||
return nil, ErrJobChannelClosed
|
||||
}
|
||||
return job, nil
|
||||
}
|
||||
}
|
||||
182
management/server/job/manager.go
Normal file
182
management/server/job/manager.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package job
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type Event struct {
|
||||
PeerID string
|
||||
Request *proto.JobRequest
|
||||
Response *proto.JobResponse
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
mu *sync.RWMutex
|
||||
jobChannels map[string]*Channel // per-peer job streams
|
||||
pending map[string]*Event // jobID → event
|
||||
responseWait time.Duration
|
||||
metrics telemetry.AppMetrics
|
||||
Store store.Store
|
||||
peersManager peers.Manager
|
||||
}
|
||||
|
||||
func NewJobManager(metrics telemetry.AppMetrics, store store.Store, peersManager peers.Manager) *Manager {
|
||||
|
||||
return &Manager{
|
||||
jobChannels: make(map[string]*Channel),
|
||||
pending: make(map[string]*Event),
|
||||
responseWait: 5 * time.Minute,
|
||||
metrics: metrics,
|
||||
mu: &sync.RWMutex{},
|
||||
Store: store,
|
||||
peersManager: peersManager,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateJobChannel creates or replaces a channel for a peer
|
||||
func (jm *Manager) CreateJobChannel(ctx context.Context, accountID, peerID string) *Channel {
|
||||
// all pending jobs stored in db for this peer should be failed
|
||||
if err := jm.Store.MarkAllPendingJobsAsFailed(ctx, accountID, peerID, "Pending job cleanup: marked as failed automatically due to being stuck too long"); err != nil {
|
||||
log.WithContext(ctx).Error(err.Error())
|
||||
}
|
||||
|
||||
jm.mu.Lock()
|
||||
defer jm.mu.Unlock()
|
||||
|
||||
if ch, ok := jm.jobChannels[peerID]; ok {
|
||||
ch.Close()
|
||||
delete(jm.jobChannels, peerID)
|
||||
}
|
||||
|
||||
ch := NewChannel()
|
||||
jm.jobChannels[peerID] = ch
|
||||
return ch
|
||||
}
|
||||
|
||||
// SendJob sends a job to a peer and tracks it as pending
|
||||
func (jm *Manager) SendJob(ctx context.Context, accountID, peerID string, req *proto.JobRequest) error {
|
||||
jm.mu.RLock()
|
||||
ch, ok := jm.jobChannels[peerID]
|
||||
jm.mu.RUnlock()
|
||||
if !ok {
|
||||
return fmt.Errorf("peer %s has no channel", peerID)
|
||||
}
|
||||
|
||||
event := &Event{
|
||||
PeerID: peerID,
|
||||
Request: req,
|
||||
}
|
||||
|
||||
jm.mu.Lock()
|
||||
jm.pending[string(req.ID)] = event
|
||||
jm.mu.Unlock()
|
||||
|
||||
if err := ch.AddEvent(ctx, jm.responseWait, event); err != nil {
|
||||
jm.cleanup(ctx, accountID, string(req.ID), err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleResponse marks a job as finished and moves it to completed
|
||||
func (jm *Manager) HandleResponse(ctx context.Context, resp *proto.JobResponse, peerKey string) error {
|
||||
jm.mu.Lock()
|
||||
defer jm.mu.Unlock()
|
||||
|
||||
// todo: validate job ID and would be nice to use uuid text marshal instead of string
|
||||
jobID := string(resp.ID)
|
||||
|
||||
// todo: in this map has jobs for all peers in any account. Consider to validate the jobID association for the peer
|
||||
event, ok := jm.pending[jobID]
|
||||
if !ok {
|
||||
return fmt.Errorf("job %s not found", jobID)
|
||||
}
|
||||
var job types.Job
|
||||
// todo: ApplyResponse should be static. Any member value is unusable in this way
|
||||
if err := job.ApplyResponse(resp); err != nil {
|
||||
return fmt.Errorf("invalid job response: %v", err)
|
||||
}
|
||||
|
||||
peerID, err := jm.peersManager.GetPeerID(ctx, peerKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get peer ID: %v", err)
|
||||
}
|
||||
if peerID != event.PeerID {
|
||||
return fmt.Errorf("peer ID mismatch: %s != %s", peerID, event.PeerID)
|
||||
}
|
||||
|
||||
// update or create the store for job response
|
||||
err = jm.Store.CompletePeerJob(ctx, &job)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to complete job %s: %v", jobID, err)
|
||||
}
|
||||
|
||||
delete(jm.pending, jobID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseChannel closes a peer’s channel and cleans up its jobs
|
||||
func (jm *Manager) CloseChannel(ctx context.Context, accountID, peerID string) {
|
||||
jm.mu.Lock()
|
||||
defer jm.mu.Unlock()
|
||||
|
||||
if ch, ok := jm.jobChannels[peerID]; ok {
|
||||
ch.Close()
|
||||
delete(jm.jobChannels, peerID)
|
||||
}
|
||||
|
||||
for jobID, ev := range jm.pending {
|
||||
if ev.PeerID == peerID {
|
||||
// if the client disconnect and there is pending job then mark it as failed
|
||||
if err := jm.Store.MarkPendingJobsAsFailed(ctx, accountID, peerID, jobID, "Time out peer disconnected"); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to mark pending jobs as failed: %v", err)
|
||||
}
|
||||
delete(jm.pending, jobID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes a pending job safely
|
||||
func (jm *Manager) cleanup(ctx context.Context, accountID, jobID string, reason string) {
|
||||
jm.mu.Lock()
|
||||
defer jm.mu.Unlock()
|
||||
|
||||
if ev, ok := jm.pending[jobID]; ok {
|
||||
if err := jm.Store.MarkPendingJobsAsFailed(ctx, accountID, ev.PeerID, jobID, reason); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to mark pending jobs as failed: %v", err)
|
||||
}
|
||||
delete(jm.pending, jobID)
|
||||
}
|
||||
}
|
||||
|
||||
func (jm *Manager) IsPeerConnected(peerID string) bool {
|
||||
jm.mu.RLock()
|
||||
defer jm.mu.RUnlock()
|
||||
|
||||
_, ok := jm.jobChannels[peerID]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (jm *Manager) IsPeerHasPendingJobs(peerID string) bool {
|
||||
jm.mu.RLock()
|
||||
defer jm.mu.RUnlock()
|
||||
|
||||
for _, ev := range jm.pending {
|
||||
if ev.PeerID == peerID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
@@ -361,13 +362,15 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
|
||||
AnyTimes()
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
||||
ephemeralMgr := manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager))
|
||||
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeralMgr, config)
|
||||
accountManager, err := BuildManager(ctx, nil, store, networkMapController, nil, "",
|
||||
accountManager, err := BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "",
|
||||
eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
|
||||
if err != nil {
|
||||
@@ -381,7 +384,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
|
||||
return nil, nil, "", cleanup, err
|
||||
}
|
||||
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil)
|
||||
if err != nil {
|
||||
return nil, nil, "", cleanup, err
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -202,6 +203,8 @@ func startServer(
|
||||
AnyTimes()
|
||||
|
||||
permissionsManager := permissions.NewManager(str)
|
||||
peersManager := peers.NewManager(str, permissionsManager)
|
||||
jobManager := job.NewJobManager(nil, str, peersManager)
|
||||
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
@@ -213,6 +216,7 @@ func startServer(
|
||||
nil,
|
||||
str,
|
||||
networkMapController,
|
||||
jobManager,
|
||||
nil,
|
||||
"",
|
||||
eventStore,
|
||||
@@ -237,6 +241,7 @@ func startServer(
|
||||
config,
|
||||
accountManager,
|
||||
settingsMockManager,
|
||||
jobManager,
|
||||
secretsManager,
|
||||
nil,
|
||||
nil,
|
||||
|
||||
@@ -74,6 +74,7 @@ type MockAccountManager struct {
|
||||
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error)
|
||||
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
||||
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
|
||||
UpdateUserPasswordFunc func(ctx context.Context, accountID, currentUserID, targetUserID string, oldPassword, newPassword string) error
|
||||
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error)
|
||||
@@ -135,6 +136,29 @@ type MockAccountManager struct {
|
||||
CreateIdentityProviderFunc func(ctx context.Context, accountID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error)
|
||||
UpdateIdentityProviderFunc func(ctx context.Context, accountID, idpID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error)
|
||||
DeleteIdentityProviderFunc func(ctx context.Context, accountID, idpID, userID string) error
|
||||
CreatePeerJobFunc func(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
|
||||
GetAllPeerJobsFunc func(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
|
||||
GetPeerJobByIDFunc func(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error {
|
||||
if am.CreatePeerJobFunc != nil {
|
||||
return am.CreatePeerJobFunc(ctx, accountID, peerID, userID, job)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method CreatePeerJob is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error) {
|
||||
if am.GetAllPeerJobsFunc != nil {
|
||||
return am.GetAllPeerJobsFunc(ctx, accountID, userID, peerID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAllPeerJobs is not implemented")
|
||||
}
|
||||
func (am *MockAccountManager) GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error) {
|
||||
if am.GetPeerJobByIDFunc != nil {
|
||||
return am.GetPeerJobByIDFunc(ctx, accountID, userID, peerID, jobID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeerJobByID is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
|
||||
@@ -612,6 +636,14 @@ func (am *MockAccountManager) DeleteRegularUsers(ctx context.Context, accountID,
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteRegularUsers is not implemented")
|
||||
}
|
||||
|
||||
// UpdateUserPassword mocks UpdateUserPassword of the AccountManager interface
|
||||
func (am *MockAccountManager) UpdateUserPassword(ctx context.Context, accountID, currentUserID, targetUserID string, oldPassword, newPassword string) error {
|
||||
if am.UpdateUserPasswordFunc != nil {
|
||||
return am.UpdateUserPasswordFunc(ctx, accountID, currentUserID, targetUserID, oldPassword, newPassword)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method UpdateUserPassword is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
|
||||
if am.InviteUserFunc != nil {
|
||||
return am.InviteUserFunc(ctx, accountID, initiatorUserID, targetUserID)
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
@@ -790,13 +791,14 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
AnyTimes()
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
|
||||
|
||||
return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
}
|
||||
|
||||
func createNSStore(t *testing.T) (store.Store, error) {
|
||||
|
||||
@@ -31,6 +31,8 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const remoteJobsMinVer = "0.64.0"
|
||||
|
||||
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
|
||||
// the current user is not an admin.
|
||||
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
|
||||
@@ -324,6 +326,134 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.RemoteJobs, operations.Create)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if p.AccountID != accountID {
|
||||
return status.NewPeerNotPartOfAccountError()
|
||||
}
|
||||
|
||||
meetMinVer, err := posture.MeetsMinVersion(remoteJobsMinVer, p.Meta.WtVersion)
|
||||
if !strings.Contains(p.Meta.WtVersion, "dev") && (!meetMinVer || err != nil) {
|
||||
return status.Errorf(status.PreconditionFailed, "peer version %s does not meet the minimum required version %s for remote jobs", p.Meta.WtVersion, remoteJobsMinVer)
|
||||
}
|
||||
|
||||
if !am.jobManager.IsPeerConnected(peerID) {
|
||||
return status.Errorf(status.BadRequest, "peer not connected")
|
||||
}
|
||||
|
||||
// check if already has pending jobs
|
||||
// todo: The job checks here are not protected. The user can run this function from multiple threads,
|
||||
// and each thread can think there is no job yet. This means entries in the pending job map will be overwritten,
|
||||
// and only one will be kept, but potentially another one will overwrite it in the queue.
|
||||
if am.jobManager.IsPeerHasPendingJobs(peerID) {
|
||||
return status.Errorf(status.BadRequest, "peer already has pending job")
|
||||
}
|
||||
|
||||
jobStream, err := job.ToStreamJobRequest()
|
||||
if err != nil {
|
||||
return status.Errorf(status.BadRequest, "invalid job request %v", err)
|
||||
}
|
||||
|
||||
// try sending job first
|
||||
if err := am.jobManager.SendJob(ctx, accountID, peerID, jobStream); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to send job: %v", err)
|
||||
}
|
||||
|
||||
var peer *nbpeer.Peer
|
||||
var eventsToStore func()
|
||||
|
||||
// persist job in DB only if send succeeded
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := transaction.CreatePeerJob(ctx, job); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
jobMeta := map[string]any{
|
||||
"for_peer_name": peer.Name,
|
||||
"job_type": job.Workload.Type,
|
||||
}
|
||||
|
||||
eventsToStore = func() {
|
||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.JobCreatedByUser, jobMeta)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
eventsToStore()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error) {
|
||||
// todo: Create permissions for job
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.RemoteJobs, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if peerAccountID != accountID {
|
||||
return nil, status.NewPeerNotPartOfAccountError()
|
||||
}
|
||||
|
||||
accountJobs, err := am.Store.GetPeerJobs(ctx, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return accountJobs, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.RemoteJobs, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if peerAccountID != accountID {
|
||||
return nil, status.NewPeerNotPartOfAccountError()
|
||||
}
|
||||
|
||||
job, err := am.Store.GetPeerJobByID(ctx, accountID, jobID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return job, nil
|
||||
}
|
||||
|
||||
// DeletePeer removes peer from the account by its IP
|
||||
func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Delete)
|
||||
@@ -598,6 +728,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
return fmt.Errorf("failed adding peer to All group: %w", err)
|
||||
}
|
||||
|
||||
if temporary {
|
||||
// we should track ephemeral peers to be able to clean them if the peer don't sync and be marked as connected
|
||||
am.networkMapController.TrackEphemeralPeer(ctx, newPeer)
|
||||
}
|
||||
|
||||
if addedByUser {
|
||||
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
|
||||
if err != nil {
|
||||
|
||||
@@ -34,6 +34,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -1289,13 +1290,14 @@ func Test_RegisterPeerByUser(t *testing.T) {
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
peersManager := peers.NewManager(s, permissionsManager)
|
||||
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, s)
|
||||
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
|
||||
|
||||
am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1374,13 +1376,14 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
peersManager := peers.NewManager(s, permissionsManager)
|
||||
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, s)
|
||||
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
|
||||
|
||||
am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1527,13 +1530,14 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
peersManager := peers.NewManager(s, permissionsManager)
|
||||
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, s)
|
||||
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
|
||||
|
||||
am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1607,13 +1611,14 @@ func Test_LoginPeer(t *testing.T) {
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
peersManager := peers.NewManager(s, permissionsManager)
|
||||
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, s)
|
||||
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
|
||||
|
||||
am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
@@ -3,35 +3,37 @@ package modules
|
||||
type Module string
|
||||
|
||||
const (
|
||||
Networks Module = "networks"
|
||||
Peers Module = "peers"
|
||||
Groups Module = "groups"
|
||||
Settings Module = "settings"
|
||||
Accounts Module = "accounts"
|
||||
Dns Module = "dns"
|
||||
Nameservers Module = "nameservers"
|
||||
Events Module = "events"
|
||||
Policies Module = "policies"
|
||||
Routes Module = "routes"
|
||||
Users Module = "users"
|
||||
SetupKeys Module = "setup_keys"
|
||||
Pats Module = "pats"
|
||||
Networks Module = "networks"
|
||||
Peers Module = "peers"
|
||||
RemoteJobs Module = "remote_jobs"
|
||||
Groups Module = "groups"
|
||||
Settings Module = "settings"
|
||||
Accounts Module = "accounts"
|
||||
Dns Module = "dns"
|
||||
Nameservers Module = "nameservers"
|
||||
Events Module = "events"
|
||||
Policies Module = "policies"
|
||||
Routes Module = "routes"
|
||||
Users Module = "users"
|
||||
SetupKeys Module = "setup_keys"
|
||||
Pats Module = "pats"
|
||||
IdentityProviders Module = "identity_providers"
|
||||
)
|
||||
|
||||
var All = map[Module]struct{}{
|
||||
Networks: {},
|
||||
Peers: {},
|
||||
Groups: {},
|
||||
Settings: {},
|
||||
Accounts: {},
|
||||
Dns: {},
|
||||
Nameservers: {},
|
||||
Events: {},
|
||||
Policies: {},
|
||||
Routes: {},
|
||||
Users: {},
|
||||
SetupKeys: {},
|
||||
Pats: {},
|
||||
Networks: {},
|
||||
Peers: {},
|
||||
RemoteJobs: {},
|
||||
Groups: {},
|
||||
Settings: {},
|
||||
Accounts: {},
|
||||
Dns: {},
|
||||
Nameservers: {},
|
||||
Events: {},
|
||||
Policies: {},
|
||||
Routes: {},
|
||||
Users: {},
|
||||
SetupKeys: {},
|
||||
Pats: {},
|
||||
IdentityProviders: {},
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -1289,13 +1290,14 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel.
|
||||
Return(&types.ExtraSettings{}, nil)
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
|
||||
|
||||
am, err := BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
am, err := BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
@@ -43,14 +43,15 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
storeSqliteFileName = "store.db"
|
||||
idQueryCondition = "id = ?"
|
||||
keyQueryCondition = "key = ?"
|
||||
mysqlKeyQueryCondition = "`key` = ?"
|
||||
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
||||
accountAndIDsQueryCondition = "account_id = ? AND id IN ?"
|
||||
accountIDCondition = "account_id = ?"
|
||||
peerNotFoundFMT = "peer %s not found"
|
||||
storeSqliteFileName = "store.db"
|
||||
idQueryCondition = "id = ?"
|
||||
keyQueryCondition = "key = ?"
|
||||
mysqlKeyQueryCondition = "`key` = ?"
|
||||
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
||||
accountAndPeerIDQueryCondition = "account_id = ? and peer_id = ?"
|
||||
accountAndIDsQueryCondition = "account_id = ? AND id IN ?"
|
||||
accountIDCondition = "account_id = ?"
|
||||
peerNotFoundFMT = "peer %s not found"
|
||||
|
||||
pgMaxConnections = 30
|
||||
pgMinConnections = 1
|
||||
@@ -125,7 +126,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||
&zones.Zone{}, &records.Record{},
|
||||
&types.Job{}, &zones.Zone{}, &records.Record{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
|
||||
@@ -144,6 +145,97 @@ func GetKeyQueryCondition(s *SqlStore) string {
|
||||
return keyQueryCondition
|
||||
}
|
||||
|
||||
// SaveJob persists a job in DB
|
||||
func (s *SqlStore) CreatePeerJob(ctx context.Context, job *types.Job) error {
|
||||
result := s.db.Create(job)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to create job in store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to create job in store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) CompletePeerJob(ctx context.Context, job *types.Job) error {
|
||||
result := s.db.
|
||||
Model(&types.Job{}).
|
||||
Where(idQueryCondition, job.ID).
|
||||
Updates(job)
|
||||
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update job in store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to update job in store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// job was pending for too long and has been cancelled
|
||||
func (s *SqlStore) MarkPendingJobsAsFailed(ctx context.Context, accountID, peerID, jobID, reason string) error {
|
||||
now := time.Now().UTC()
|
||||
result := s.db.
|
||||
Model(&types.Job{}).
|
||||
Where(accountAndPeerIDQueryCondition+" AND id = ?"+" AND status = ?", accountID, peerID, jobID, types.JobStatusPending).
|
||||
Updates(types.Job{
|
||||
Status: types.JobStatusFailed,
|
||||
FailedReason: reason,
|
||||
CompletedAt: &now,
|
||||
})
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to mark pending jobs as Failed job in store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to mark pending job as Failed in store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// job was pending for too long and has been cancelled
|
||||
func (s *SqlStore) MarkAllPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error {
|
||||
now := time.Now().UTC()
|
||||
result := s.db.
|
||||
Model(&types.Job{}).
|
||||
Where(accountAndPeerIDQueryCondition+" AND status = ?", accountID, peerID, types.JobStatusPending).
|
||||
Updates(types.Job{
|
||||
Status: types.JobStatusFailed,
|
||||
FailedReason: reason,
|
||||
CompletedAt: &now,
|
||||
})
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to mark pending jobs as Failed job in store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to mark pending job as Failed in store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetJobByID fetches job by ID
|
||||
func (s *SqlStore) GetPeerJobByID(ctx context.Context, accountID, jobID string) (*types.Job, error) {
|
||||
var job types.Job
|
||||
err := s.db.
|
||||
Where(accountAndIDQueryCondition, accountID, jobID).
|
||||
First(&job).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "job %s not found", jobID)
|
||||
}
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to fetch job from store: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
return &job, nil
|
||||
}
|
||||
|
||||
// get all jobs
|
||||
func (s *SqlStore) GetPeerJobs(ctx context.Context, accountID, peerID string) ([]*types.Job, error) {
|
||||
var jobs []*types.Job
|
||||
err := s.db.
|
||||
Where(accountAndPeerIDQueryCondition, accountID, peerID).
|
||||
Order("created_at DESC").
|
||||
Find(&jobs).Error
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to fetch jobs from store: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return jobs, nil
|
||||
}
|
||||
|
||||
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
|
||||
func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring global lock")
|
||||
@@ -4363,3 +4455,23 @@ func (s *SqlStore) DeleteZoneDNSRecords(ctx context.Context, accountID, zoneID s
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerIDByKey(ctx context.Context, lockStrength LockingStrength, key string) (string, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var peerID string
|
||||
result := tx.Model(&nbpeer.Peer{}).
|
||||
Select("id").
|
||||
Where(GetKeyQueryCondition(s), key).
|
||||
Limit(1).
|
||||
Scan(&peerID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peer ID by key: %s", result.Error)
|
||||
return "", status.Errorf(status.Internal, "failed to get peer ID by key")
|
||||
}
|
||||
|
||||
return peerID, nil
|
||||
}
|
||||
|
||||
@@ -226,6 +226,13 @@ type Store interface {
|
||||
GetZoneDNSRecords(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) ([]*records.Record, error)
|
||||
GetZoneDNSRecordsByName(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, name string) ([]*records.Record, error)
|
||||
DeleteZoneDNSRecords(ctx context.Context, accountID, zoneID string) error
|
||||
CreatePeerJob(ctx context.Context, job *types.Job) error
|
||||
CompletePeerJob(ctx context.Context, job *types.Job) error
|
||||
GetPeerJobByID(ctx context.Context, accountID, jobID string) (*types.Job, error)
|
||||
GetPeerJobs(ctx context.Context, accountID, peerID string) ([]*types.Job, error)
|
||||
MarkPendingJobsAsFailed(ctx context.Context, accountID, peerID, jobID, reason string) error
|
||||
MarkAllPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error
|
||||
GetPeerIDByKey(ctx context.Context, lockStrength LockingStrength, key string) (string, error)
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user