diff --git a/client/cmd/debug.go b/client/cmd/debug.go index bbb0ef0d6..e480df4d7 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -219,11 +219,33 @@ func runForDuration(cmd *cobra.Command, args []string) error { time.Sleep(3 * time.Second) + cpuProfilingStarted := false + if _, err := client.StartCPUProfile(cmd.Context(), &proto.StartCPUProfileRequest{}); err != nil { + cmd.PrintErrf("Failed to start CPU profiling: %v\n", err) + } else { + cpuProfilingStarted = true + defer func() { + if cpuProfilingStarted { + if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil { + cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err) + } + } + }() + } + if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil { return waitErr } cmd.Println("\nDuration completed") + if cpuProfilingStarted { + if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil { + cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err) + } else { + cpuProfilingStarted = false + } + } + cmd.Println("Creating debug bundle...") request := &proto.DebugBundleRequest{ @@ -353,6 +375,7 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c StatusRecorder: recorder, SyncResponse: syncResponse, LogPath: logFilePath, + CPUProfile: nil, }, debug.BundleConfig{ IncludeSystemInfo: true, diff --git a/client/iface/bind/dual_stack_conn.go b/client/iface/bind/dual_stack_conn.go new file mode 100644 index 000000000..061016ecc --- /dev/null +++ b/client/iface/bind/dual_stack_conn.go @@ -0,0 +1,169 @@ +package bind + +import ( + "errors" + "net" + "sync" + "time" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" +) + +var ( + errNoIPv4Conn = errors.New("no IPv4 connection available") + errNoIPv6Conn = errors.New("no IPv6 connection available") + errInvalidAddr = errors.New("invalid address type") +) + +// DualStackPacketConn wraps IPv4 and IPv6 UDP connections and routes writes +// to the appropriate connection based on the destination address. +// ReadFrom is not used in the hot path - ICEBind receives packets via +// BatchReader.ReadBatch() directly. This is only used by udpMux for sending. +type DualStackPacketConn struct { + ipv4Conn net.PacketConn + ipv6Conn net.PacketConn + + readFromWarn sync.Once +} + +// NewDualStackPacketConn creates a new dual-stack packet connection. +func NewDualStackPacketConn(ipv4Conn, ipv6Conn net.PacketConn) *DualStackPacketConn { + return &DualStackPacketConn{ + ipv4Conn: ipv4Conn, + ipv6Conn: ipv6Conn, + } +} + +// ReadFrom reads from the available connection (preferring IPv4). +// NOTE: This method is NOT used in the data path. ICEBind receives packets via +// BatchReader.ReadBatch() directly for both IPv4 and IPv6, which is much more efficient. +// This implementation exists only to satisfy the net.PacketConn interface for the udpMux, +// but the udpMux only uses WriteTo() for sending STUN responses - it never calls ReadFrom() +// because STUN packets are filtered and forwarded via HandleSTUNMessage() from the receive path. +func (d *DualStackPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + d.readFromWarn.Do(func() { + log.Warn("DualStackPacketConn.ReadFrom called - this is unexpected and may indicate an inefficient code path") + }) + + if d.ipv4Conn != nil { + return d.ipv4Conn.ReadFrom(b) + } + if d.ipv6Conn != nil { + return d.ipv6Conn.ReadFrom(b) + } + return 0, nil, net.ErrClosed +} + +// WriteTo writes to the appropriate connection based on the address type. +func (d *DualStackPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + return 0, &net.OpError{ + Op: "write", + Net: "udp", + Addr: addr, + Err: errInvalidAddr, + } + } + + if udpAddr.IP.To4() == nil { + if d.ipv6Conn != nil { + return d.ipv6Conn.WriteTo(b, addr) + } + return 0, &net.OpError{ + Op: "write", + Net: "udp6", + Addr: addr, + Err: errNoIPv6Conn, + } + } + + if d.ipv4Conn != nil { + return d.ipv4Conn.WriteTo(b, addr) + } + return 0, &net.OpError{ + Op: "write", + Net: "udp4", + Addr: addr, + Err: errNoIPv4Conn, + } +} + +// Close closes both connections. +func (d *DualStackPacketConn) Close() error { + var result *multierror.Error + if d.ipv4Conn != nil { + if err := d.ipv4Conn.Close(); err != nil { + result = multierror.Append(result, err) + } + } + if d.ipv6Conn != nil { + if err := d.ipv6Conn.Close(); err != nil { + result = multierror.Append(result, err) + } + } + return nberrors.FormatErrorOrNil(result) +} + +// LocalAddr returns the local address of the IPv4 connection if available, +// otherwise the IPv6 connection. +func (d *DualStackPacketConn) LocalAddr() net.Addr { + if d.ipv4Conn != nil { + return d.ipv4Conn.LocalAddr() + } + if d.ipv6Conn != nil { + return d.ipv6Conn.LocalAddr() + } + return nil +} + +// SetDeadline sets the deadline for both connections. +func (d *DualStackPacketConn) SetDeadline(t time.Time) error { + var result *multierror.Error + if d.ipv4Conn != nil { + if err := d.ipv4Conn.SetDeadline(t); err != nil { + result = multierror.Append(result, err) + } + } + if d.ipv6Conn != nil { + if err := d.ipv6Conn.SetDeadline(t); err != nil { + result = multierror.Append(result, err) + } + } + return nberrors.FormatErrorOrNil(result) +} + +// SetReadDeadline sets the read deadline for both connections. +func (d *DualStackPacketConn) SetReadDeadline(t time.Time) error { + var result *multierror.Error + if d.ipv4Conn != nil { + if err := d.ipv4Conn.SetReadDeadline(t); err != nil { + result = multierror.Append(result, err) + } + } + if d.ipv6Conn != nil { + if err := d.ipv6Conn.SetReadDeadline(t); err != nil { + result = multierror.Append(result, err) + } + } + return nberrors.FormatErrorOrNil(result) +} + +// SetWriteDeadline sets the write deadline for both connections. +func (d *DualStackPacketConn) SetWriteDeadline(t time.Time) error { + var result *multierror.Error + if d.ipv4Conn != nil { + if err := d.ipv4Conn.SetWriteDeadline(t); err != nil { + result = multierror.Append(result, err) + } + } + if d.ipv6Conn != nil { + if err := d.ipv6Conn.SetWriteDeadline(t); err != nil { + result = multierror.Append(result, err) + } + } + return nberrors.FormatErrorOrNil(result) +} diff --git a/client/iface/bind/dual_stack_conn_bench_test.go b/client/iface/bind/dual_stack_conn_bench_test.go new file mode 100644 index 000000000..940c44966 --- /dev/null +++ b/client/iface/bind/dual_stack_conn_bench_test.go @@ -0,0 +1,119 @@ +package bind + +import ( + "net" + "testing" +) + +var ( + ipv4Addr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345} + ipv6Addr = &net.UDPAddr{IP: net.ParseIP("::1"), Port: 12345} + payload = make([]byte, 1200) +) + +func BenchmarkWriteTo_DirectUDPConn(b *testing.B) { + conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + b.Fatal(err) + } + defer conn.Close() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = conn.WriteTo(payload, ipv4Addr) + } +} + +func BenchmarkWriteTo_DualStack_IPv4Only(b *testing.B) { + conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + b.Fatal(err) + } + defer conn.Close() + + ds := NewDualStackPacketConn(conn, nil) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ds.WriteTo(payload, ipv4Addr) + } +} + +func BenchmarkWriteTo_DualStack_IPv6Only(b *testing.B) { + conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) + if err != nil { + b.Skipf("IPv6 not available: %v", err) + } + defer conn.Close() + + ds := NewDualStackPacketConn(nil, conn) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ds.WriteTo(payload, ipv6Addr) + } +} + +func BenchmarkWriteTo_DualStack_Both_IPv4Traffic(b *testing.B) { + conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + b.Fatal(err) + } + defer conn4.Close() + + conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) + if err != nil { + b.Skipf("IPv6 not available: %v", err) + } + defer conn6.Close() + + ds := NewDualStackPacketConn(conn4, conn6) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ds.WriteTo(payload, ipv4Addr) + } +} + +func BenchmarkWriteTo_DualStack_Both_IPv6Traffic(b *testing.B) { + conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + b.Fatal(err) + } + defer conn4.Close() + + conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) + if err != nil { + b.Skipf("IPv6 not available: %v", err) + } + defer conn6.Close() + + ds := NewDualStackPacketConn(conn4, conn6) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ds.WriteTo(payload, ipv6Addr) + } +} + +func BenchmarkWriteTo_DualStack_Both_MixedTraffic(b *testing.B) { + conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + b.Fatal(err) + } + defer conn4.Close() + + conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) + if err != nil { + b.Skipf("IPv6 not available: %v", err) + } + defer conn6.Close() + + ds := NewDualStackPacketConn(conn4, conn6) + addrs := []net.Addr{ipv4Addr, ipv6Addr} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ds.WriteTo(payload, addrs[i&1]) + } +} diff --git a/client/iface/bind/dual_stack_conn_test.go b/client/iface/bind/dual_stack_conn_test.go new file mode 100644 index 000000000..3007d907f --- /dev/null +++ b/client/iface/bind/dual_stack_conn_test.go @@ -0,0 +1,191 @@ +package bind + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDualStackPacketConn_RoutesWritesToCorrectSocket(t *testing.T) { + ipv4Conn := &mockPacketConn{network: "udp4"} + ipv6Conn := &mockPacketConn{network: "udp6"} + dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn) + + tests := []struct { + name string + addr *net.UDPAddr + wantSocket string + }{ + { + name: "IPv4 address", + addr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}, + wantSocket: "udp4", + }, + { + name: "IPv6 address", + addr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234}, + wantSocket: "udp6", + }, + { + name: "IPv4-mapped IPv6 goes to IPv4", + addr: &net.UDPAddr{IP: net.ParseIP("::ffff:192.168.1.1"), Port: 1234}, + wantSocket: "udp4", + }, + { + name: "IPv4 loopback", + addr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234}, + wantSocket: "udp4", + }, + { + name: "IPv6 loopback", + addr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 1234}, + wantSocket: "udp6", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ipv4Conn.writeCount = 0 + ipv6Conn.writeCount = 0 + + n, err := dualStack.WriteTo([]byte("test"), tt.addr) + require.NoError(t, err) + assert.Equal(t, 4, n) + + if tt.wantSocket == "udp4" { + assert.Equal(t, 1, ipv4Conn.writeCount, "expected write to IPv4") + assert.Equal(t, 0, ipv6Conn.writeCount, "expected no write to IPv6") + } else { + assert.Equal(t, 0, ipv4Conn.writeCount, "expected no write to IPv4") + assert.Equal(t, 1, ipv6Conn.writeCount, "expected write to IPv6") + } + }) + } +} + +func TestDualStackPacketConn_IPv4OnlyRejectsIPv6(t *testing.T) { + dualStack := NewDualStackPacketConn(&mockPacketConn{network: "udp4"}, nil) + + // IPv4 works + _, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}) + require.NoError(t, err) + + // IPv6 fails + _, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234}) + require.Error(t, err) + assert.Contains(t, err.Error(), "no IPv6 connection") +} + +func TestDualStackPacketConn_IPv6OnlyRejectsIPv4(t *testing.T) { + dualStack := NewDualStackPacketConn(nil, &mockPacketConn{network: "udp6"}) + + // IPv6 works + _, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234}) + require.NoError(t, err) + + // IPv4 fails + _, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}) + require.Error(t, err) + assert.Contains(t, err.Error(), "no IPv4 connection") +} + +// TestDualStackPacketConn_ReadFromIsNotUsedInHotPath documents that ReadFrom +// only reads from one socket (IPv4 preferred). This is fine because the actual +// receive path uses wireguard-go's BatchReader directly, not ReadFrom. +func TestDualStackPacketConn_ReadFromIsNotUsedInHotPath(t *testing.T) { + ipv4Conn := &mockPacketConn{ + network: "udp4", + readData: []byte("from ipv4"), + readAddr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}, + } + ipv6Conn := &mockPacketConn{ + network: "udp6", + readData: []byte("from ipv6"), + readAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234}, + } + + dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn) + + buf := make([]byte, 100) + n, addr, err := dualStack.ReadFrom(buf) + + require.NoError(t, err) + // reads from IPv4 (preferred) - this is expected behavior + assert.Equal(t, "from ipv4", string(buf[:n])) + assert.Equal(t, "192.168.1.1", addr.(*net.UDPAddr).IP.String()) +} + +func TestDualStackPacketConn_LocalAddrPrefersIPv4(t *testing.T) { + ipv4Addr := &net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: 51820} + ipv6Addr := &net.UDPAddr{IP: net.ParseIP("::"), Port: 51820} + + tests := []struct { + name string + ipv4 net.PacketConn + ipv6 net.PacketConn + wantAddr net.Addr + }{ + { + name: "both available returns IPv4", + ipv4: &mockPacketConn{localAddr: ipv4Addr}, + ipv6: &mockPacketConn{localAddr: ipv6Addr}, + wantAddr: ipv4Addr, + }, + { + name: "IPv4 only", + ipv4: &mockPacketConn{localAddr: ipv4Addr}, + ipv6: nil, + wantAddr: ipv4Addr, + }, + { + name: "IPv6 only", + ipv4: nil, + ipv6: &mockPacketConn{localAddr: ipv6Addr}, + wantAddr: ipv6Addr, + }, + { + name: "neither returns nil", + ipv4: nil, + ipv6: nil, + wantAddr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dualStack := NewDualStackPacketConn(tt.ipv4, tt.ipv6) + assert.Equal(t, tt.wantAddr, dualStack.LocalAddr()) + }) + } +} + +// mock + +type mockPacketConn struct { + network string + writeCount int + readData []byte + readAddr net.Addr + localAddr net.Addr +} + +func (m *mockPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + if m.readData != nil { + return copy(b, m.readData), m.readAddr, nil + } + return 0, nil, nil +} + +func (m *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + m.writeCount++ + return len(b), nil +} + +func (m *mockPacketConn) Close() error { return nil } +func (m *mockPacketConn) LocalAddr() net.Addr { return m.localAddr } +func (m *mockPacketConn) SetDeadline(t time.Time) error { return nil } +func (m *mockPacketConn) SetReadDeadline(t time.Time) error { return nil } +func (m *mockPacketConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index 0957d2dd5..bf79ecd79 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -14,7 +14,6 @@ import ( "github.com/pion/stun/v3" "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" - "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" wgConn "golang.zx2c4.com/wireguard/conn" @@ -28,22 +27,7 @@ type receiverCreator struct { } func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc { - if ipv4PC, ok := pc.(*ipv4.PacketConn); ok { - return rc.iceBind.createIPv4ReceiverFn(ipv4PC, conn, rxOffload, msgPool) - } - // IPv6 is currently not supported in the udpmux, this is a stub for compatibility with the - // wireguard-go ReceiverCreator interface which is called for both IPv4 and IPv6. - return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { - buf := bufs[0] - size, ep, err := conn.ReadFromUDPAddrPort(buf) - if err != nil { - return 0, err - } - sizes[0] = size - stdEp := &wgConn.StdNetEndpoint{AddrPort: ep} - eps[0] = stdEp - return 1, nil - } + return rc.iceBind.createReceiverFn(pc, conn, rxOffload, msgPool) } // ICEBind is a bind implementation with two main features: @@ -73,6 +57,8 @@ type ICEBind struct { muUDPMux sync.Mutex udpMux *udpmux.UniversalUDPMuxDefault + ipv4Conn *net.UDPConn + ipv6Conn *net.UDPConn } func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind { @@ -118,6 +104,12 @@ func (s *ICEBind) Close() error { close(s.closedChan) + s.muUDPMux.Lock() + s.ipv4Conn = nil + s.ipv6Conn = nil + s.udpMux = nil + s.muUDPMux.Unlock() + return s.StdNetBind.Close() } @@ -175,19 +167,18 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { return nil } -func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc { +func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc { s.muUDPMux.Lock() defer s.muUDPMux.Unlock() - s.udpMux = udpmux.NewUniversalUDPMuxDefault( - udpmux.UniversalUDPMuxParams{ - UDPConn: nbnet.WrapPacketConn(conn), - Net: s.transportNet, - FilterFn: s.filterFn, - WGAddress: s.address, - MTU: s.mtu, - }, - ) + // Detect IPv4 vs IPv6 from connection's local address + if localAddr := conn.LocalAddr().(*net.UDPAddr); localAddr.IP.To4() != nil { + s.ipv4Conn = conn + } else { + s.ipv6Conn = conn + } + s.createOrUpdateMux() + return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { msgs := getMessages(msgsPool) for i := range bufs { @@ -195,12 +186,13 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] } defer putMessages(msgs, msgsPool) + var numMsgs int if runtime.GOOS == "linux" || runtime.GOOS == "android" { if rxOffload { readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams) - //nolint - numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0) + //nolint:staticcheck + _, err = pc.ReadBatch((*msgs)[readAt:], 0) if err != nil { return 0, err } @@ -222,12 +214,12 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r } numMsgs = 1 } + for i := 0; i < numMsgs; i++ { msg := &(*msgs)[i] // todo: handle err - ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr) - if ok { + if ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr); ok { continue } sizes[i] = msg.N @@ -248,6 +240,38 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r } } +// createOrUpdateMux creates or updates the UDP mux with the available connections. +// Must be called with muUDPMux held. +func (s *ICEBind) createOrUpdateMux() { + var muxConn net.PacketConn + + switch { + case s.ipv4Conn != nil && s.ipv6Conn != nil: + muxConn = NewDualStackPacketConn( + nbnet.WrapPacketConn(s.ipv4Conn), + nbnet.WrapPacketConn(s.ipv6Conn), + ) + case s.ipv4Conn != nil: + muxConn = nbnet.WrapPacketConn(s.ipv4Conn) + case s.ipv6Conn != nil: + muxConn = nbnet.WrapPacketConn(s.ipv6Conn) + default: + return + } + + // Don't close the old mux - it doesn't own the underlying connections. + // The sockets are managed by WireGuard's StdNetBind, not by us. + s.udpMux = udpmux.NewUniversalUDPMuxDefault( + udpmux.UniversalUDPMuxParams{ + UDPConn: muxConn, + Net: s.transportNet, + FilterFn: s.filterFn, + WGAddress: s.address, + MTU: s.mtu, + }, + ) +} + func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) { for i := range buffers { if !stun.IsMessage(buffers[i]) { @@ -260,9 +284,14 @@ func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) return true, err } - muxErr := s.udpMux.HandleSTUNMessage(msg, addr) - if muxErr != nil { - log.Warnf("failed to handle STUN packet") + s.muUDPMux.Lock() + mux := s.udpMux + s.muUDPMux.Unlock() + + if mux != nil { + if muxErr := mux.HandleSTUNMessage(msg, addr); muxErr != nil { + log.Warnf("failed to handle STUN packet: %v", muxErr) + } } buffers[i] = []byte{} diff --git a/client/iface/bind/ice_bind_test.go b/client/iface/bind/ice_bind_test.go new file mode 100644 index 000000000..1fdd955c9 --- /dev/null +++ b/client/iface/bind/ice_bind_test.go @@ -0,0 +1,324 @@ +package bind + +import ( + "fmt" + "net" + "net/netip" + "sync" + "testing" + "time" + + "github.com/pion/transport/v3/stdnet" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + + "github.com/netbirdio/netbird/client/iface/wgaddr" +) + +func TestICEBind_CreatesReceiverForBothIPv4AndIPv6(t *testing.T) { + iceBind := setupICEBind(t) + + ipv4Conn, ipv6Conn := createDualStackConns(t) + defer ipv4Conn.Close() + defer ipv6Conn.Close() + + rc := receiverCreator{iceBind} + pool := createMsgPool() + + // Simulate wireguard-go calling CreateReceiverFn for IPv4 + ipv4RecvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, pool) + require.NotNil(t, ipv4RecvFn) + + iceBind.muUDPMux.Lock() + assert.NotNil(t, iceBind.ipv4Conn, "should store IPv4 connection") + assert.Nil(t, iceBind.ipv6Conn, "IPv6 not added yet") + assert.NotNil(t, iceBind.udpMux, "mux should be created after first connection") + iceBind.muUDPMux.Unlock() + + // Simulate wireguard-go calling CreateReceiverFn for IPv6 + ipv6RecvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, pool) + require.NotNil(t, ipv6RecvFn) + + iceBind.muUDPMux.Lock() + assert.NotNil(t, iceBind.ipv4Conn, "should still have IPv4 connection") + assert.NotNil(t, iceBind.ipv6Conn, "should now have IPv6 connection") + assert.NotNil(t, iceBind.udpMux, "mux should still exist") + iceBind.muUDPMux.Unlock() + + mux, err := iceBind.GetICEMux() + require.NoError(t, err) + require.NotNil(t, mux) +} + +func TestICEBind_WorksWithIPv4Only(t *testing.T) { + iceBind := setupICEBind(t) + + ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + require.NoError(t, err) + defer ipv4Conn.Close() + + rc := receiverCreator{iceBind} + recvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, createMsgPool()) + require.NotNil(t, recvFn) + + iceBind.muUDPMux.Lock() + assert.NotNil(t, iceBind.ipv4Conn) + assert.Nil(t, iceBind.ipv6Conn) + assert.NotNil(t, iceBind.udpMux) + iceBind.muUDPMux.Unlock() + + mux, err := iceBind.GetICEMux() + require.NoError(t, err) + require.NotNil(t, mux) +} + +func TestICEBind_WorksWithIPv6Only(t *testing.T) { + iceBind := setupICEBind(t) + + ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) + if err != nil { + t.Skipf("IPv6 not available: %v", err) + } + defer ipv6Conn.Close() + + rc := receiverCreator{iceBind} + recvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, createMsgPool()) + require.NotNil(t, recvFn) + + iceBind.muUDPMux.Lock() + assert.Nil(t, iceBind.ipv4Conn) + assert.NotNil(t, iceBind.ipv6Conn) + assert.NotNil(t, iceBind.udpMux) + iceBind.muUDPMux.Unlock() + + mux, err := iceBind.GetICEMux() + require.NoError(t, err) + require.NotNil(t, mux) +} + +// TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously verifies that we can communicate +// with peers on different address families through the same DualStackPacketConn. +func TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously(t *testing.T) { + // two "remote peers" listening on different address families + ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0") + defer ipv4Peer.Close() + + ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0}) + if err != nil { + t.Skipf("IPv6 not available: %v", err) + } + defer ipv6Peer.Close() + + // our local dual-stack connection + ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0") + defer ipv4Local.Close() + + ipv6Local := listenUDP(t, "udp6", "[::1]:0") + defer ipv6Local.Close() + + dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local) + + // send to both peers + _, err = dualStack.WriteTo([]byte("to-ipv4"), ipv4Peer.LocalAddr()) + require.NoError(t, err) + + _, err = dualStack.WriteTo([]byte("to-ipv6"), ipv6Peer.LocalAddr()) + require.NoError(t, err) + + // verify IPv4 peer got its packet from the IPv4 socket + buf := make([]byte, 100) + _ = ipv4Peer.SetReadDeadline(time.Now().Add(time.Second)) + n, addr, err := ipv4Peer.ReadFrom(buf) + require.NoError(t, err) + assert.Equal(t, "to-ipv4", string(buf[:n])) + assert.Equal(t, ipv4Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port) + + // verify IPv6 peer got its packet from the IPv6 socket + _ = ipv6Peer.SetReadDeadline(time.Now().Add(time.Second)) + n, addr, err = ipv6Peer.ReadFrom(buf) + require.NoError(t, err) + assert.Equal(t, "to-ipv6", string(buf[:n])) + assert.Equal(t, ipv6Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port) +} + +// TestICEBind_HandlesConcurrentMixedTraffic sends packets concurrently to both IPv4 +// and IPv6 peers. Verifies no packets get misrouted (IPv4 peer only gets v4- packets, +// IPv6 peer only gets v6- packets). Some packet loss is acceptable for UDP. +func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) { + ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0") + defer ipv4Peer.Close() + + ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0}) + if err != nil { + t.Skipf("IPv6 not available: %v", err) + } + defer ipv6Peer.Close() + + ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0") + defer ipv4Local.Close() + + ipv6Local := listenUDP(t, "udp6", "[::1]:0") + defer ipv6Local.Close() + + dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local) + + const packetsPerFamily = 500 + + ipv4Received := make(chan string, packetsPerFamily) + ipv6Received := make(chan string, packetsPerFamily) + + startGate := make(chan struct{}) + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, 100) + for i := 0; i < packetsPerFamily; i++ { + n, _, err := ipv4Peer.ReadFrom(buf) + if err != nil { + return + } + ipv4Received <- string(buf[:n]) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, 100) + for i := 0; i < packetsPerFamily; i++ { + n, _, err := ipv6Peer.ReadFrom(buf) + if err != nil { + return + } + ipv6Received <- string(buf[:n]) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + <-startGate + for i := 0; i < packetsPerFamily; i++ { + _, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v4-%04d", i)), ipv4Peer.LocalAddr()) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + <-startGate + for i := 0; i < packetsPerFamily; i++ { + _, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v6-%04d", i)), ipv6Peer.LocalAddr()) + } + }() + + close(startGate) + + time.AfterFunc(5*time.Second, func() { + _ = ipv4Peer.SetReadDeadline(time.Now()) + _ = ipv6Peer.SetReadDeadline(time.Now()) + }) + + wg.Wait() + close(ipv4Received) + close(ipv6Received) + + ipv4Count := 0 + for pkt := range ipv4Received { + require.True(t, len(pkt) >= 3 && pkt[:3] == "v4-", "IPv4 peer got misrouted packet: %s", pkt) + ipv4Count++ + } + + ipv6Count := 0 + for pkt := range ipv6Received { + require.True(t, len(pkt) >= 3 && pkt[:3] == "v6-", "IPv6 peer got misrouted packet: %s", pkt) + ipv6Count++ + } + + assert.Equal(t, packetsPerFamily, ipv4Count) + assert.Equal(t, packetsPerFamily, ipv6Count) +} + +func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) { + tests := []struct { + name string + network string + addr string + wantIPv4 bool + }{ + {"IPv4 any", "udp4", "0.0.0.0:0", true}, + {"IPv4 loopback", "udp4", "127.0.0.1:0", true}, + {"IPv6 any", "udp6", "[::]:0", false}, + {"IPv6 loopback", "udp6", "[::1]:0", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addr, err := net.ResolveUDPAddr(tt.network, tt.addr) + require.NoError(t, err) + + conn, err := net.ListenUDP(tt.network, addr) + if err != nil { + t.Skipf("%s not available: %v", tt.network, err) + } + defer conn.Close() + + localAddr := conn.LocalAddr().(*net.UDPAddr) + isIPv4 := localAddr.IP.To4() != nil + assert.Equal(t, tt.wantIPv4, isIPv4) + }) + } +} + +// helpers + +func setupICEBind(t *testing.T) *ICEBind { + t.Helper() + transportNet, err := stdnet.NewNet() + require.NoError(t, err) + + address := wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/10"), + } + return NewICEBind(transportNet, nil, address, 1280) +} + +func createDualStackConns(t *testing.T) (*net.UDPConn, *net.UDPConn) { + t.Helper() + ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + require.NoError(t, err) + + ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) + if err != nil { + ipv4Conn.Close() + t.Skipf("IPv6 not available: %v", err) + } + return ipv4Conn, ipv6Conn +} + +func createMsgPool() *sync.Pool { + return &sync.Pool{ + New: func() any { + msgs := make([]ipv6.Message, 1) + for i := range msgs { + msgs[i].Buffers = make(net.Buffers, 1) + msgs[i].OOB = make([]byte, 0, 40) + } + return &msgs + }, + } +} + +func listenUDP(t *testing.T, network, addr string) *net.UDPConn { + t.Helper() + udpAddr, err := net.ResolveUDPAddr(network, addr) + require.NoError(t, err) + conn, err := net.ListenUDP(network, udpAddr) + require.NoError(t, err) + return conn +} diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index d3b5bc9d4..07a19036a 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -59,6 +59,7 @@ block.prof: Block profiling information. heap.prof: Heap profiling information (snapshot of memory allocations). allocs.prof: Allocations profiling information. threadcreate.prof: Thread creation profiling information. +cpu.prof: CPU profiling information. stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation. @@ -226,6 +227,7 @@ type BundleGenerator struct { statusRecorder *peer.Status syncResponse *mgmProto.SyncResponse logPath string + cpuProfile []byte anonymize bool includeSystemInfo bool @@ -245,6 +247,7 @@ type GeneratorDependencies struct { StatusRecorder *peer.Status SyncResponse *mgmProto.SyncResponse LogPath string + CPUProfile []byte } func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator { @@ -261,6 +264,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen statusRecorder: deps.StatusRecorder, syncResponse: deps.SyncResponse, logPath: deps.LogPath, + cpuProfile: deps.CPUProfile, anonymize: cfg.Anonymize, includeSystemInfo: cfg.IncludeSystemInfo, @@ -324,6 +328,10 @@ func (g *BundleGenerator) createArchive() error { log.Errorf("failed to add profiles to debug bundle: %v", err) } + if err := g.addCPUProfile(); err != nil { + log.Errorf("failed to add CPU profile to debug bundle: %v", err) + } + if err := g.addStackTrace(); err != nil { log.Errorf("failed to add stack trace to debug bundle: %v", err) } @@ -542,6 +550,19 @@ func (g *BundleGenerator) addProf() (err error) { return nil } +func (g *BundleGenerator) addCPUProfile() error { + if len(g.cpuProfile) == 0 { + return nil + } + + reader := bytes.NewReader(g.cpuProfile) + if err := g.addFileToZip(reader, "cpu.prof"); err != nil { + return fmt.Errorf("add CPU profile to zip: %w", err) + } + + return nil +} + func (g *BundleGenerator) addStackTrace() error { buf := make([]byte, 5242880) // 5 MB buffer n := runtime.Stack(buf, true) diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 840fc9241..b6b9d2cf4 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/netip" + "strconv" "sync" "time" @@ -286,8 +287,8 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent RosenpassAddr: remoteOfferAnswer.RosenpassAddr, LocalIceCandidateType: pair.Local.Type().String(), RemoteIceCandidateType: pair.Remote.Type().String(), - LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()), - RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()), + LocalIceCandidateEndpoint: net.JoinHostPort(pair.Local.Address(), strconv.Itoa(pair.Local.Port())), + RemoteIceCandidateEndpoint: net.JoinHostPort(pair.Remote.Address(), strconv.Itoa(pair.Remote.Port())), Relayed: isRelayed(pair), RelayedOnLocal: isRelayCandidate(pair.Local), } @@ -328,13 +329,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) { // wait local endpoint configuration time.Sleep(time.Second) - addrString := pair.Remote.Address() - parsed, err := netip.ParseAddr(addrString) - if (err == nil) && (parsed.Is6()) { - addrString = fmt.Sprintf("[%s]", addrString) - //IPv6 Literals need to be wrapped in brackets for Resolve*Addr() - } - addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addrString, remoteWgPort)) + addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(pair.Remote.Address(), strconv.Itoa(remoteWgPort))) if err != nil { w.log.Warnf("got an error while resolving the udp address, err: %s", err) return @@ -386,12 +381,44 @@ func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, } } +func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) { + sessionID := w.SessionID() + stats := agent.GetCandidatePairsStats() + localCandidates, _ := agent.GetLocalCandidates() + remoteCandidates, _ := agent.GetRemoteCandidates() + + localMap := make(map[string]ice.Candidate) + for _, c := range localCandidates { + localMap[c.ID()] = c + } + remoteMap := make(map[string]ice.Candidate) + for _, c := range remoteCandidates { + remoteMap[c.ID()] = c + } + + for _, stat := range stats { + if stat.State == ice.CandidatePairStateSucceeded { + local, lok := localMap[stat.LocalCandidateID] + remote, rok := remoteMap[stat.RemoteCandidateID] + if !lok || !rok { + continue + } + w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms", + sessionID, + local.NetworkType(), local.Type(), local.Address(), + remote.NetworkType(), remote.Type(), remote.Address(), + stat.CurrentRoundTripTime*1000) + } + } +} + func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dialerCancel context.CancelFunc) func(ice.ConnectionState) { return func(state ice.ConnectionState) { w.log.Debugf("ICE ConnectionState has changed to %s", state.String()) switch state { case ice.ConnectionStateConnected: w.lastKnownState = ice.ConnectionStateConnected + w.logSuccessfulPaths(agent) return case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected, ice.ConnectionStateClosed: // ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 9cbe34e1d..1d9d7233c 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.6 -// protoc v6.33.1 +// protoc v6.32.1 // source: daemon.proto package proto @@ -5364,6 +5364,154 @@ func (x *WaitJWTTokenResponse) GetExpiresIn() int64 { return 0 } +// StartCPUProfileRequest for starting CPU profiling +type StartCPUProfileRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StartCPUProfileRequest) Reset() { + *x = StartCPUProfileRequest{} + mi := &file_daemon_proto_msgTypes[79] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StartCPUProfileRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StartCPUProfileRequest) ProtoMessage() {} + +func (x *StartCPUProfileRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[79] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StartCPUProfileRequest.ProtoReflect.Descriptor instead. +func (*StartCPUProfileRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{79} +} + +// StartCPUProfileResponse confirms CPU profiling has started +type StartCPUProfileResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StartCPUProfileResponse) Reset() { + *x = StartCPUProfileResponse{} + mi := &file_daemon_proto_msgTypes[80] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StartCPUProfileResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StartCPUProfileResponse) ProtoMessage() {} + +func (x *StartCPUProfileResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[80] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StartCPUProfileResponse.ProtoReflect.Descriptor instead. +func (*StartCPUProfileResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{80} +} + +// StopCPUProfileRequest for stopping CPU profiling +type StopCPUProfileRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StopCPUProfileRequest) Reset() { + *x = StopCPUProfileRequest{} + mi := &file_daemon_proto_msgTypes[81] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StopCPUProfileRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StopCPUProfileRequest) ProtoMessage() {} + +func (x *StopCPUProfileRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[81] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StopCPUProfileRequest.ProtoReflect.Descriptor instead. +func (*StopCPUProfileRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{81} +} + +// StopCPUProfileResponse confirms CPU profiling has stopped +type StopCPUProfileResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StopCPUProfileResponse) Reset() { + *x = StopCPUProfileResponse{} + mi := &file_daemon_proto_msgTypes[82] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StopCPUProfileResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StopCPUProfileResponse) ProtoMessage() {} + +func (x *StopCPUProfileResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[82] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StopCPUProfileResponse.ProtoReflect.Descriptor instead. +func (*StopCPUProfileResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{82} +} + type InstallerResultRequest struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -5372,7 +5520,7 @@ type InstallerResultRequest struct { func (x *InstallerResultRequest) Reset() { *x = InstallerResultRequest{} - mi := &file_daemon_proto_msgTypes[79] + mi := &file_daemon_proto_msgTypes[83] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5384,7 +5532,7 @@ func (x *InstallerResultRequest) String() string { func (*InstallerResultRequest) ProtoMessage() {} func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[79] + mi := &file_daemon_proto_msgTypes[83] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5397,7 +5545,7 @@ func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead. func (*InstallerResultRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{79} + return file_daemon_proto_rawDescGZIP(), []int{83} } type InstallerResultResponse struct { @@ -5410,7 +5558,7 @@ type InstallerResultResponse struct { func (x *InstallerResultResponse) Reset() { *x = InstallerResultResponse{} - mi := &file_daemon_proto_msgTypes[80] + mi := &file_daemon_proto_msgTypes[84] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5422,7 +5570,7 @@ func (x *InstallerResultResponse) String() string { func (*InstallerResultResponse) ProtoMessage() {} func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[80] + mi := &file_daemon_proto_msgTypes[84] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5435,7 +5583,7 @@ func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead. func (*InstallerResultResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{80} + return file_daemon_proto_rawDescGZIP(), []int{84} } func (x *InstallerResultResponse) GetSuccess() bool { @@ -5462,7 +5610,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} - mi := &file_daemon_proto_msgTypes[82] + mi := &file_daemon_proto_msgTypes[86] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5474,7 +5622,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[82] + mi := &file_daemon_proto_msgTypes[86] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5994,6 +6142,10 @@ const file_daemon_proto_rawDesc = "" + "\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" + "\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" + "\texpiresIn\x18\x03 \x01(\x03R\texpiresIn\"\x18\n" + + "\x16StartCPUProfileRequest\"\x19\n" + + "\x17StartCPUProfileResponse\"\x17\n" + + "\x15StopCPUProfileRequest\"\x18\n" + + "\x16StopCPUProfileResponse\"\x18\n" + "\x16InstallerResultRequest\"O\n" + "\x17InstallerResultResponse\x12\x18\n" + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" + @@ -6006,7 +6158,7 @@ const file_daemon_proto_rawDesc = "" + "\x04WARN\x10\x04\x12\b\n" + "\x04INFO\x10\x05\x12\t\n" + "\x05DEBUG\x10\x06\x12\t\n" + - "\x05TRACE\x10\a2\xb4\x13\n" + + "\x05TRACE\x10\a2\xdd\x14\n" + "\rDaemonService\x126\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + @@ -6041,7 +6193,9 @@ const file_daemon_proto_rawDesc = "" + "\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" + "\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" + "\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" + - "\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12N\n" + + "\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12T\n" + + "\x0fStartCPUProfile\x12\x1e.daemon.StartCPUProfileRequest\x1a\x1f.daemon.StartCPUProfileResponse\"\x00\x12Q\n" + + "\x0eStopCPUProfile\x12\x1d.daemon.StopCPUProfileRequest\x1a\x1e.daemon.StopCPUProfileResponse\"\x00\x12N\n" + "\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" + "\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3" @@ -6058,7 +6212,7 @@ func file_daemon_proto_rawDescGZIP() []byte { } var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 84) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 88) var file_daemon_proto_goTypes = []any{ (LogLevel)(0), // 0: daemon.LogLevel (OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType @@ -6143,21 +6297,25 @@ var file_daemon_proto_goTypes = []any{ (*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse (*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest (*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse - (*InstallerResultRequest)(nil), // 83: daemon.InstallerResultRequest - (*InstallerResultResponse)(nil), // 84: daemon.InstallerResultResponse - nil, // 85: daemon.Network.ResolvedIPsEntry - (*PortInfo_Range)(nil), // 86: daemon.PortInfo.Range - nil, // 87: daemon.SystemEvent.MetadataEntry - (*durationpb.Duration)(nil), // 88: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 89: google.protobuf.Timestamp + (*StartCPUProfileRequest)(nil), // 83: daemon.StartCPUProfileRequest + (*StartCPUProfileResponse)(nil), // 84: daemon.StartCPUProfileResponse + (*StopCPUProfileRequest)(nil), // 85: daemon.StopCPUProfileRequest + (*StopCPUProfileResponse)(nil), // 86: daemon.StopCPUProfileResponse + (*InstallerResultRequest)(nil), // 87: daemon.InstallerResultRequest + (*InstallerResultResponse)(nil), // 88: daemon.InstallerResultResponse + nil, // 89: daemon.Network.ResolvedIPsEntry + (*PortInfo_Range)(nil), // 90: daemon.PortInfo.Range + nil, // 91: daemon.SystemEvent.MetadataEntry + (*durationpb.Duration)(nil), // 92: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 93: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ 1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType - 88, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 92, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 89, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 89, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 88, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 93, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 93, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 92, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration 25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo 22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState @@ -6168,8 +6326,8 @@ var file_daemon_proto_depIdxs = []int32{ 57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent 26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState 33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network - 85, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry - 86, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range + 89, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry + 90, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range 34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo 34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo 35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule @@ -6180,10 +6338,10 @@ var file_daemon_proto_depIdxs = []int32{ 54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage 2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity 3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category - 89, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp - 87, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry + 93, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp + 91, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry 57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent - 88, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 92, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile 32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList 7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest @@ -6217,43 +6375,47 @@ var file_daemon_proto_depIdxs = []int32{ 77, // 62: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest 79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest 81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest - 5, // 65: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest - 83, // 66: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest - 8, // 67: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 10, // 68: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 12, // 69: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 14, // 70: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 16, // 71: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 18, // 72: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 29, // 73: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse - 31, // 74: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse - 31, // 75: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse - 36, // 76: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse - 38, // 77: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 40, // 78: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 42, // 79: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 45, // 80: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse - 47, // 81: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse - 49, // 82: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse - 51, // 83: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse - 55, // 84: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse - 57, // 85: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent - 59, // 86: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse - 61, // 87: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse - 63, // 88: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse - 65, // 89: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse - 67, // 90: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse - 69, // 91: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse - 72, // 92: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse - 74, // 93: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse - 76, // 94: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse - 78, // 95: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse - 80, // 96: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse - 82, // 97: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse - 6, // 98: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse - 84, // 99: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse - 67, // [67:100] is the sub-list for method output_type - 34, // [34:67] is the sub-list for method input_type + 83, // 65: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest + 85, // 66: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest + 5, // 67: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest + 87, // 68: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest + 8, // 69: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 10, // 70: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 12, // 71: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 14, // 72: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 16, // 73: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 18, // 74: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 29, // 75: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 31, // 76: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 31, // 77: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse + 36, // 78: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse + 38, // 79: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 40, // 80: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 42, // 81: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 45, // 82: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse + 47, // 83: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse + 49, // 84: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse + 51, // 85: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse + 55, // 86: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse + 57, // 87: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent + 59, // 88: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse + 61, // 89: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse + 63, // 90: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse + 65, // 91: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse + 67, // 92: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse + 69, // 93: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse + 72, // 94: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse + 74, // 95: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse + 76, // 96: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse + 78, // 97: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse + 80, // 98: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse + 82, // 99: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse + 84, // 100: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse + 86, // 101: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse + 6, // 102: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse + 88, // 103: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse + 69, // [69:104] is the sub-list for method output_type + 34, // [34:69] is the sub-list for method input_type 34, // [34:34] is the sub-list for extension type_name 34, // [34:34] is the sub-list for extension extendee 0, // [0:34] is the sub-list for field type_name @@ -6283,7 +6445,7 @@ func file_daemon_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)), NumEnums: 4, - NumMessages: 84, + NumMessages: 88, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 7a802d830..68b9a9348 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -94,6 +94,12 @@ service DaemonService { // WaitJWTToken waits for JWT authentication completion rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {} +// StartCPUProfile starts CPU profiling in the daemon + rpc StartCPUProfile(StartCPUProfileRequest) returns (StartCPUProfileResponse) {} + + // StopCPUProfile stops CPU profiling in the daemon + rpc StopCPUProfile(StopCPUProfileRequest) returns (StopCPUProfileResponse) {} + rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {} rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {} @@ -776,6 +782,18 @@ message WaitJWTTokenResponse { int64 expiresIn = 3; } +// StartCPUProfileRequest for starting CPU profiling +message StartCPUProfileRequest {} + +// StartCPUProfileResponse confirms CPU profiling has started +message StartCPUProfileResponse {} + +// StopCPUProfileRequest for stopping CPU profiling +message StopCPUProfileRequest {} + +// StopCPUProfileResponse confirms CPU profiling has stopped +message StopCPUProfileResponse {} + message InstallerResultRequest { } diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index fdabb1879..ea9b4df05 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -70,6 +70,10 @@ type DaemonServiceClient interface { RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error) // WaitJWTToken waits for JWT authentication completion WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) + // StartCPUProfile starts CPU profiling in the daemon + StartCPUProfile(ctx context.Context, in *StartCPUProfileRequest, opts ...grpc.CallOption) (*StartCPUProfileResponse, error) + // StopCPUProfile stops CPU profiling in the daemon + StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error) } @@ -384,6 +388,24 @@ func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTToken return out, nil } +func (c *daemonServiceClient) StartCPUProfile(ctx context.Context, in *StartCPUProfileRequest, opts ...grpc.CallOption) (*StartCPUProfileResponse, error) { + out := new(StartCPUProfileResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/StartCPUProfile", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error) { + out := new(StopCPUProfileResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/StopCPUProfile", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) { out := new(OSLifecycleResponse) err := c.cc.Invoke(ctx, "/daemon.DaemonService/NotifyOSLifecycle", in, out, opts...) @@ -458,6 +480,10 @@ type DaemonServiceServer interface { RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error) // WaitJWTToken waits for JWT authentication completion WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) + // StartCPUProfile starts CPU profiling in the daemon + StartCPUProfile(context.Context, *StartCPUProfileRequest) (*StartCPUProfileResponse, error) + // StopCPUProfile stops CPU profiling in the daemon + StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) mustEmbedUnimplementedDaemonServiceServer() @@ -560,6 +586,12 @@ func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *Request func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented") } +func (UnimplementedDaemonServiceServer) StartCPUProfile(context.Context, *StartCPUProfileRequest) (*StartCPUProfileResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method StartCPUProfile not implemented") +} +func (UnimplementedDaemonServiceServer) StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method StopCPUProfile not implemented") +} func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented") } @@ -1140,6 +1172,42 @@ func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, d return interceptor(ctx, in, info, handler) } +func _DaemonService_StartCPUProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(StartCPUProfileRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).StartCPUProfile(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/StartCPUProfile", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).StartCPUProfile(ctx, req.(*StartCPUProfileRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_StopCPUProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(StopCPUProfileRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).StopCPUProfile(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/StopCPUProfile", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).StopCPUProfile(ctx, req.(*StopCPUProfileRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(OSLifecycleRequest) if err := dec(in); err != nil { @@ -1303,6 +1371,14 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "WaitJWTToken", Handler: _DaemonService_WaitJWTToken_Handler, }, + { + MethodName: "StartCPUProfile", + Handler: _DaemonService_StartCPUProfile_Handler, + }, + { + MethodName: "StopCPUProfile", + Handler: _DaemonService_StopCPUProfile_Handler, + }, { MethodName: "NotifyOSLifecycle", Handler: _DaemonService_NotifyOSLifecycle_Handler, diff --git a/client/server/debug.go b/client/server/debug.go index 104fd30f4..5646cea79 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -3,9 +3,11 @@ package server import ( + "bytes" "context" "errors" "fmt" + "runtime/pprof" log "github.com/sirupsen/logrus" @@ -24,12 +26,21 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( log.Warnf("failed to get latest sync response: %v", err) } + var cpuProfileData []byte + if s.cpuProfileBuf != nil && !s.cpuProfiling { + cpuProfileData = s.cpuProfileBuf.Bytes() + defer func() { + s.cpuProfileBuf = nil + }() + } + bundleGenerator := debug.NewBundleGenerator( debug.GeneratorDependencies{ InternalConfig: s.config, StatusRecorder: s.statusRecorder, SyncResponse: syncResponse, LogPath: s.logFile, + CPUProfile: cpuProfileData, }, debug.BundleConfig{ Anonymize: req.GetAnonymize(), @@ -109,3 +120,43 @@ func (s *Server) getLatestSyncResponse() (*mgmProto.SyncResponse, error) { return cClient.GetLatestSyncResponse() } + +// StartCPUProfile starts CPU profiling in the daemon. +func (s *Server) StartCPUProfile(_ context.Context, _ *proto.StartCPUProfileRequest) (*proto.StartCPUProfileResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.cpuProfiling { + return nil, fmt.Errorf("CPU profiling already in progress") + } + + s.cpuProfileBuf = &bytes.Buffer{} + s.cpuProfiling = true + if err := pprof.StartCPUProfile(s.cpuProfileBuf); err != nil { + s.cpuProfileBuf = nil + s.cpuProfiling = false + return nil, fmt.Errorf("start CPU profile: %w", err) + } + + log.Info("CPU profiling started") + return &proto.StartCPUProfileResponse{}, nil +} + +// StopCPUProfile stops CPU profiling in the daemon. +func (s *Server) StopCPUProfile(_ context.Context, _ *proto.StopCPUProfileRequest) (*proto.StopCPUProfileResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if !s.cpuProfiling { + return nil, fmt.Errorf("CPU profiling not in progress") + } + + pprof.StopCPUProfile() + s.cpuProfiling = false + + if s.cpuProfileBuf != nil { + log.Infof("CPU profiling stopped, captured %d bytes", s.cpuProfileBuf.Len()) + } + + return &proto.StopCPUProfileResponse{}, nil +} diff --git a/client/server/server.go b/client/server/server.go index 408bd56db..e3c95077a 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -1,6 +1,7 @@ package server import ( + "bytes" "context" "errors" "fmt" @@ -77,6 +78,9 @@ type Server struct { persistSyncResponse bool isSessionActive atomic.Bool + cpuProfileBuf *bytes.Buffer + cpuProfiling bool + profileManager *profilemanager.ServiceManager profilesDisabled bool updateSettingsDisabled bool diff --git a/client/status/status.go b/client/status/status.go index be28ff67d..f13163a41 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -491,6 +491,11 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total) + var forwardingRulesString string + if o.NumberOfForwardingRules > 0 { + forwardingRulesString = fmt.Sprintf("Forwarding rules: %d\n", o.NumberOfForwardingRules) + } + goos := runtime.GOOS goarch := runtime.GOARCH goarm := "" @@ -514,7 +519,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS "Lazy connection: %s\n"+ "SSH Server: %s\n"+ "Networks: %s\n"+ - "Forwarding rules: %d\n"+ + "%s"+ "Peers count: %s\n", fmt.Sprintf("%s/%s%s", goos, goarch, goarm), o.DaemonVersion, @@ -531,7 +536,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS lazyConnectionEnabledStatus, sshServerStatus, networks, - o.NumberOfForwardingRules, + forwardingRulesString, peersCountString, ) return summary diff --git a/client/status/status_test.go b/client/status/status_test.go index ad158722b..b02d78d64 100644 --- a/client/status/status_test.go +++ b/client/status/status_test.go @@ -567,7 +567,6 @@ Quantum resistance: false Lazy connection: false SSH Server: Disabled Networks: 10.10.0.0/24 -Forwarding rules: 0 Peers count: 2/2 Connected `, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion) @@ -592,7 +591,6 @@ Quantum resistance: false Lazy connection: false SSH Server: Disabled Networks: 10.10.0.0/24 -Forwarding rules: 0 Peers count: 2/2 Connected ` diff --git a/client/ui/debug.go b/client/ui/debug.go index e9bcfde41..29f73a66a 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -406,6 +406,10 @@ func (s *serviceClient) configureServiceForDebug( } time.Sleep(time.Second * 3) + if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil { + log.Warnf("failed to start CPU profiling: %v", err) + } + return nil } @@ -428,6 +432,10 @@ func (s *serviceClient) collectDebugData( progress.progressBar.Hide() progress.statusLabel.SetText("Collecting debug data...") + if _, err := conn.StopCPUProfile(s.ctx, &proto.StopCPUProfileRequest{}); err != nil { + log.Warnf("failed to stop CPU profiling: %v", err) + } + return nil } diff --git a/go.mod b/go.mod index cb16fff52..8ac5613ee 100644 --- a/go.mod +++ b/go.mod @@ -68,7 +68,7 @@ require ( github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 - github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 + github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/oapi-codegen/runtime v1.1.2 github.com/okta/okta-sdk-golang/v2 v2.18.0 diff --git a/go.sum b/go.sum index c59acbb23..6adc7f7e8 100644 --- a/go.sum +++ b/go.sum @@ -406,8 +406,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 h1:V0zsYYMU5d2UN1m9zOLPEZCGWpnhtkYcxQVi9Rrx3bY= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo= +github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f h1:CTBf0je/FpKr2lVSMZLak7m8aaWcS6ur4SOfhSSazFI= +github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/management/server/activity/store/crypt.go b/management/server/activity/store/crypt.go deleted file mode 100644 index ce97347d4..000000000 --- a/management/server/activity/store/crypt.go +++ /dev/null @@ -1,136 +0,0 @@ -package store - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "encoding/base64" - "errors" -) - -var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05} - -type FieldEncrypt struct { - block cipher.Block - gcm cipher.AEAD -} - -func GenerateKey() (string, error) { - key := make([]byte, 32) - _, err := rand.Read(key) - if err != nil { - return "", err - } - readableKey := base64.StdEncoding.EncodeToString(key) - return readableKey, nil -} - -func NewFieldEncrypt(key string) (*FieldEncrypt, error) { - binKey, err := base64.StdEncoding.DecodeString(key) - if err != nil { - return nil, err - } - - block, err := aes.NewCipher(binKey) - if err != nil { - return nil, err - } - - gcm, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - - ec := &FieldEncrypt{ - block: block, - gcm: gcm, - } - - return ec, nil -} - -func (ec *FieldEncrypt) LegacyEncrypt(payload string) string { - plainText := pkcs5Padding([]byte(payload)) - cipherText := make([]byte, len(plainText)) - cbc := cipher.NewCBCEncrypter(ec.block, iv) - cbc.CryptBlocks(cipherText, plainText) - return base64.StdEncoding.EncodeToString(cipherText) -} - -// Encrypt encrypts plaintext using AES-GCM -func (ec *FieldEncrypt) Encrypt(payload string) (string, error) { - plaintext := []byte(payload) - nonceSize := ec.gcm.NonceSize() - - nonce := make([]byte, nonceSize, len(plaintext)+nonceSize+ec.gcm.Overhead()) - if _, err := rand.Read(nonce); err != nil { - return "", err - } - - ciphertext := ec.gcm.Seal(nonce, nonce, plaintext, nil) - - return base64.StdEncoding.EncodeToString(ciphertext), nil -} - -func (ec *FieldEncrypt) LegacyDecrypt(data string) (string, error) { - cipherText, err := base64.StdEncoding.DecodeString(data) - if err != nil { - return "", err - } - cbc := cipher.NewCBCDecrypter(ec.block, iv) - cbc.CryptBlocks(cipherText, cipherText) - payload, err := pkcs5UnPadding(cipherText) - if err != nil { - return "", err - } - - return string(payload), nil -} - -// Decrypt decrypts ciphertext using AES-GCM -func (ec *FieldEncrypt) Decrypt(data string) (string, error) { - cipherText, err := base64.StdEncoding.DecodeString(data) - if err != nil { - return "", err - } - - nonceSize := ec.gcm.NonceSize() - if len(cipherText) < nonceSize { - return "", errors.New("cipher text too short") - } - - nonce, cipherText := cipherText[:nonceSize], cipherText[nonceSize:] - plainText, err := ec.gcm.Open(nil, nonce, cipherText, nil) - if err != nil { - return "", err - } - - return string(plainText), nil -} - -func pkcs5Padding(ciphertext []byte) []byte { - padding := aes.BlockSize - len(ciphertext)%aes.BlockSize - padText := bytes.Repeat([]byte{byte(padding)}, padding) - return append(ciphertext, padText...) -} -func pkcs5UnPadding(src []byte) ([]byte, error) { - srcLen := len(src) - if srcLen == 0 { - return nil, errors.New("input data is empty") - } - - paddingLen := int(src[srcLen-1]) - if paddingLen == 0 || paddingLen > aes.BlockSize || paddingLen > srcLen { - return nil, errors.New("invalid padding size") - } - - // Verify that all padding bytes are the same - for i := 0; i < paddingLen; i++ { - if src[srcLen-1-i] != byte(paddingLen) { - return nil, errors.New("invalid padding") - } - } - - return src[:srcLen-paddingLen], nil -} diff --git a/management/server/activity/store/crypt_test.go b/management/server/activity/store/crypt_test.go deleted file mode 100644 index 700bbcd6b..000000000 --- a/management/server/activity/store/crypt_test.go +++ /dev/null @@ -1,310 +0,0 @@ -package store - -import ( - "bytes" - "testing" -) - -func TestGenerateKey(t *testing.T) { - testData := "exampl@netbird.io" - key, err := GenerateKey() - if err != nil { - t.Fatalf("failed to generate key: %s", err) - } - ee, err := NewFieldEncrypt(key) - if err != nil { - t.Fatalf("failed to init email encryption: %s", err) - } - - encrypted, err := ee.Encrypt(testData) - if err != nil { - t.Fatalf("failed to encrypt data: %s", err) - } - - if encrypted == "" { - t.Fatalf("invalid encrypted text") - } - - decrypted, err := ee.Decrypt(encrypted) - if err != nil { - t.Fatalf("failed to decrypt data: %s", err) - } - - if decrypted != testData { - t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted) - } -} - -func TestGenerateKeyLegacy(t *testing.T) { - testData := "exampl@netbird.io" - key, err := GenerateKey() - if err != nil { - t.Fatalf("failed to generate key: %s", err) - } - ee, err := NewFieldEncrypt(key) - if err != nil { - t.Fatalf("failed to init email encryption: %s", err) - } - - encrypted := ee.LegacyEncrypt(testData) - if encrypted == "" { - t.Fatalf("invalid encrypted text") - } - - decrypted, err := ee.LegacyDecrypt(encrypted) - if err != nil { - t.Fatalf("failed to decrypt data: %s", err) - } - - if decrypted != testData { - t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted) - } -} - -func TestCorruptKey(t *testing.T) { - testData := "exampl@netbird.io" - key, err := GenerateKey() - if err != nil { - t.Fatalf("failed to generate key: %s", err) - } - ee, err := NewFieldEncrypt(key) - if err != nil { - t.Fatalf("failed to init email encryption: %s", err) - } - - encrypted, err := ee.Encrypt(testData) - if err != nil { - t.Fatalf("failed to encrypt data: %s", err) - } - - if encrypted == "" { - t.Fatalf("invalid encrypted text") - } - - newKey, err := GenerateKey() - if err != nil { - t.Fatalf("failed to generate key: %s", err) - } - - ee, err = NewFieldEncrypt(newKey) - if err != nil { - t.Fatalf("failed to init email encryption: %s", err) - } - - res, _ := ee.Decrypt(encrypted) - if res == testData { - t.Fatalf("incorrect decryption, the result is: %s", res) - } -} - -func TestEncryptDecrypt(t *testing.T) { - // Generate a key for encryption/decryption - key, err := GenerateKey() - if err != nil { - t.Fatalf("Failed to generate key: %v", err) - } - - // Initialize the FieldEncrypt with the generated key - ec, err := NewFieldEncrypt(key) - if err != nil { - t.Fatalf("Failed to create FieldEncrypt: %v", err) - } - - // Test cases - testCases := []struct { - name string - input string - }{ - { - name: "Empty String", - input: "", - }, - { - name: "Short String", - input: "Hello", - }, - { - name: "String with Spaces", - input: "Hello, World!", - }, - { - name: "Long String", - input: "The quick brown fox jumps over the lazy dog.", - }, - { - name: "Unicode Characters", - input: "こんにちは世界", - }, - { - name: "Special Characters", - input: "!@#$%^&*()_+-=[]{}|;':\",./<>?", - }, - { - name: "Numeric String", - input: "1234567890", - }, - { - name: "Repeated Characters", - input: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", - }, - { - name: "Multi-block String", - input: "This is a longer string that will span multiple blocks in the encryption algorithm.", - }, - { - name: "Non-ASCII and ASCII Mix", - input: "Hello 世界 123", - }, - } - - for _, tc := range testCases { - t.Run(tc.name+" - Legacy", func(t *testing.T) { - // Legacy Encryption - encryptedLegacy := ec.LegacyEncrypt(tc.input) - if encryptedLegacy == "" { - t.Errorf("LegacyEncrypt returned empty string for input '%s'", tc.input) - } - - // Legacy Decryption - decryptedLegacy, err := ec.LegacyDecrypt(encryptedLegacy) - if err != nil { - t.Errorf("LegacyDecrypt failed for input '%s': %v", tc.input, err) - } - - // Verify that the decrypted value matches the original input - if decryptedLegacy != tc.input { - t.Errorf("LegacyDecrypt output '%s' does not match original input '%s'", decryptedLegacy, tc.input) - } - }) - - t.Run(tc.name+" - New", func(t *testing.T) { - // New Encryption - encryptedNew, err := ec.Encrypt(tc.input) - if err != nil { - t.Errorf("Encrypt failed for input '%s': %v", tc.input, err) - } - if encryptedNew == "" { - t.Errorf("Encrypt returned empty string for input '%s'", tc.input) - } - - // New Decryption - decryptedNew, err := ec.Decrypt(encryptedNew) - if err != nil { - t.Errorf("Decrypt failed for input '%s': %v", tc.input, err) - } - - // Verify that the decrypted value matches the original input - if decryptedNew != tc.input { - t.Errorf("Decrypt output '%s' does not match original input '%s'", decryptedNew, tc.input) - } - }) - } -} - -func TestPKCS5UnPadding(t *testing.T) { - tests := []struct { - name string - input []byte - expected []byte - expectError bool - }{ - { - name: "Valid Padding", - input: append([]byte("Hello, World!"), bytes.Repeat([]byte{4}, 4)...), - expected: []byte("Hello, World!"), - }, - { - name: "Empty Input", - input: []byte{}, - expectError: true, - }, - { - name: "Padding Length Zero", - input: append([]byte("Hello, World!"), bytes.Repeat([]byte{0}, 4)...), - expectError: true, - }, - { - name: "Padding Length Exceeds Block Size", - input: append([]byte("Hello, World!"), bytes.Repeat([]byte{17}, 17)...), - expectError: true, - }, - { - name: "Padding Length Exceeds Input Length", - input: []byte{5, 5, 5}, - expectError: true, - }, - { - name: "Invalid Padding Bytes", - input: append([]byte("Hello, World!"), []byte{2, 3, 4, 5}...), - expectError: true, - }, - { - name: "Valid Single Byte Padding", - input: append([]byte("Hello, World!"), byte(1)), - expected: []byte("Hello, World!"), - }, - { - name: "Invalid Mixed Padding Bytes", - input: append([]byte("Hello, World!"), []byte{3, 3, 2}...), - expectError: true, - }, - { - name: "Valid Full Block Padding", - input: append([]byte("Hello, World!"), bytes.Repeat([]byte{16}, 16)...), - expected: []byte("Hello, World!"), - }, - { - name: "Non-Padding Byte at End", - input: append([]byte("Hello, World!"), []byte{4, 4, 4, 5}...), - expectError: true, - }, - { - name: "Valid Padding with Different Text Length", - input: append([]byte("Test"), bytes.Repeat([]byte{12}, 12)...), - expected: []byte("Test"), - }, - { - name: "Padding Length Equal to Input Length", - input: bytes.Repeat([]byte{8}, 8), - expected: []byte{}, - }, - { - name: "Invalid Padding Length Zero (Again)", - input: append([]byte("Test"), byte(0)), - expectError: true, - }, - { - name: "Padding Length Greater Than Input", - input: []byte{10}, - expectError: true, - }, - { - name: "Input Length Not Multiple of Block Size", - input: append([]byte("Invalid Length"), byte(1)), - expected: []byte("Invalid Length"), - }, - { - name: "Valid Padding with Non-ASCII Characters", - input: append([]byte("こんにちは"), bytes.Repeat([]byte{2}, 2)...), - expected: []byte("こんにちは"), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := pkcs5UnPadding(tt.input) - if tt.expectError { - if err == nil { - t.Errorf("Expected error but got nil") - } - } else { - if err != nil { - t.Errorf("Did not expect error but got: %v", err) - } - if !bytes.Equal(result, tt.expected) { - t.Errorf("Expected output %v, got %v", tt.expected, result) - } - } - }) - } -} diff --git a/management/server/activity/store/migration.go b/management/server/activity/store/migration.go index af19a34eb..d0f165d5f 100644 --- a/management/server/activity/store/migration.go +++ b/management/server/activity/store/migration.go @@ -10,9 +10,10 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/migration" + "github.com/netbirdio/netbird/util/crypt" ) -func migrate(ctx context.Context, crypt *FieldEncrypt, db *gorm.DB) error { +func migrate(ctx context.Context, crypt *crypt.FieldEncrypt, db *gorm.DB) error { migrations := getMigrations(ctx, crypt) for _, m := range migrations { @@ -26,7 +27,7 @@ func migrate(ctx context.Context, crypt *FieldEncrypt, db *gorm.DB) error { type migrationFunc func(*gorm.DB) error -func getMigrations(ctx context.Context, crypt *FieldEncrypt) []migrationFunc { +func getMigrations(ctx context.Context, crypt *crypt.FieldEncrypt) []migrationFunc { return []migrationFunc{ func(db *gorm.DB) error { return migration.MigrateNewField[activity.DeletedUser](ctx, db, "name", "") @@ -45,7 +46,7 @@ func getMigrations(ctx context.Context, crypt *FieldEncrypt) []migrationFunc { // migrateLegacyEncryptedUsersToGCM migrates previously encrypted data using // legacy CBC encryption with a static IV to the new GCM encryption method. -func migrateLegacyEncryptedUsersToGCM(ctx context.Context, db *gorm.DB, crypt *FieldEncrypt) error { +func migrateLegacyEncryptedUsersToGCM(ctx context.Context, db *gorm.DB, crypt *crypt.FieldEncrypt) error { model := &activity.DeletedUser{} if !db.Migrator().HasTable(model) { @@ -80,7 +81,7 @@ func migrateLegacyEncryptedUsersToGCM(ctx context.Context, db *gorm.DB, crypt *F return nil } -func updateDeletedUserData(transaction *gorm.DB, user activity.DeletedUser, crypt *FieldEncrypt) error { +func updateDeletedUserData(transaction *gorm.DB, user activity.DeletedUser, crypt *crypt.FieldEncrypt) error { var err error var decryptedEmail, decryptedName string diff --git a/management/server/activity/store/migration_test.go b/management/server/activity/store/migration_test.go index e3261d9fa..5c6f5ade8 100644 --- a/management/server/activity/store/migration_test.go +++ b/management/server/activity/store/migration_test.go @@ -12,6 +12,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/migration" "github.com/netbirdio/netbird/management/server/testutil" + "github.com/netbirdio/netbird/util/crypt" ) const ( @@ -40,10 +41,10 @@ func setupDatabase(t *testing.T) *gorm.DB { func TestMigrateLegacyEncryptedUsersToGCM(t *testing.T) { db := setupDatabase(t) - key, err := GenerateKey() + key, err := crypt.GenerateKey() require.NoError(t, err, "Failed to generate key") - crypt, err := NewFieldEncrypt(key) + crypt, err := crypt.NewFieldEncrypt(key) require.NoError(t, err, "Failed to initialize FieldEncrypt") t.Run("empty table, no migration required", func(t *testing.T) { diff --git a/management/server/activity/store/sql_store.go b/management/server/activity/store/sql_store.go index ffecb6b8f..db614d0cd 100644 --- a/management/server/activity/store/sql_store.go +++ b/management/server/activity/store/sql_store.go @@ -18,6 +18,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/util/crypt" ) const ( @@ -45,12 +46,12 @@ type eventWithNames struct { // Store is the implementation of the activity.Store interface backed by SQLite type Store struct { db *gorm.DB - fieldEncrypt *FieldEncrypt + fieldEncrypt *crypt.FieldEncrypt } // NewSqlStore creates a new Store with an event table if not exists. func NewSqlStore(ctx context.Context, dataDir string, encryptionKey string) (*Store, error) { - crypt, err := NewFieldEncrypt(encryptionKey) + fieldEncrypt, err := crypt.NewFieldEncrypt(encryptionKey) if err != nil { return nil, err @@ -61,7 +62,7 @@ func NewSqlStore(ctx context.Context, dataDir string, encryptionKey string) (*St return nil, fmt.Errorf("initialize database: %w", err) } - if err = migrate(ctx, crypt, db); err != nil { + if err = migrate(ctx, fieldEncrypt, db); err != nil { return nil, fmt.Errorf("events database migration: %w", err) } @@ -72,7 +73,7 @@ func NewSqlStore(ctx context.Context, dataDir string, encryptionKey string) (*St return &Store{ db: db, - fieldEncrypt: crypt, + fieldEncrypt: fieldEncrypt, }, nil } diff --git a/management/server/activity/store/sql_store_test.go b/management/server/activity/store/sql_store_test.go index 8c0d159df..d723f1623 100644 --- a/management/server/activity/store/sql_store_test.go +++ b/management/server/activity/store/sql_store_test.go @@ -9,11 +9,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/util/crypt" ) func TestNewSqlStore(t *testing.T) { dataDir := t.TempDir() - key, _ := GenerateKey() + key, _ := crypt.GenerateKey() store, err := NewSqlStore(context.Background(), dataDir, key) if err != nil { t.Fatal(err) diff --git a/management/server/user.go b/management/server/user.go index 1f38b749f..0a090d681 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -704,7 +704,7 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, ac "is_service_user": oldUser.IsServiceUser, "user_name": oldUser.ServiceUserName, } eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, oldUser.Id, oldUser.Id, accountID, activity.GroupAddedToUser, meta) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser, meta) }) } @@ -718,7 +718,7 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, ac "is_service_user": oldUser.IsServiceUser, "user_name": oldUser.ServiceUserName, } eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, oldUser.Id, oldUser.Id, accountID, activity.GroupRemovedFromUser, meta) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser, meta) }) } @@ -1282,7 +1282,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI addPeerRemovedEvent() } - meta := map[string]any{"name": targetUserInfo.Name, "email": targetUserInfo.Email, "created_at": targetUser.CreatedAt} + meta := map[string]any{"name": targetUserInfo.Name, "email": targetUserInfo.Email, "created_at": targetUser.CreatedAt, "issued": targetUser.Issued} am.StoreEvent(ctx, initiatorUserID, targetUser.Id, accountID, activity.UserDeleted, meta) return updateAccountPeers, nil diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index bc2d4d1be..523beb32b 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -154,9 +154,20 @@ func (s *SharedSocket) updateRouter() { } } -// LocalAddr returns an IPv4 address using the supplied port +// LocalAddr returns the local address, preferring IPv4 for backward compatibility. func (s *SharedSocket) LocalAddr() net.Addr { - // todo check impact on ipv6 discovery + if s.conn4 != nil { + return &net.UDPAddr{ + IP: net.IPv4zero, + Port: s.port, + } + } + if s.conn6 != nil { + return &net.UDPAddr{ + IP: net.IPv6zero, + Port: s.port, + } + } return &net.UDPAddr{ IP: net.IPv4zero, Port: s.port, diff --git a/util/crypt/crypt_test.go b/util/crypt/crypt_test.go new file mode 100644 index 000000000..143a4bbc2 --- /dev/null +++ b/util/crypt/crypt_test.go @@ -0,0 +1,139 @@ +package crypt + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGenerateKey(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + assert.NotEmpty(t, key) + + _, err = NewFieldEncrypt(key) + assert.NoError(t, err) +} + +func TestNewFieldEncrypt_InvalidKey(t *testing.T) { + tests := []struct { + name string + key string + }{ + {name: "invalid base64", key: "not-valid-base64!!!"}, + {name: "too short", key: "c2hvcnQ="}, + {name: "empty", key: ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewFieldEncrypt(tt.key) + assert.Error(t, err) + }) + } +} + +func TestEncryptDecrypt(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + ec, err := NewFieldEncrypt(key) + require.NoError(t, err) + + testCases := []struct { + name string + input string + }{ + {name: "Empty String", input: ""}, + {name: "Short String", input: "Hello"}, + {name: "String with Spaces", input: "Hello, World!"}, + {name: "Long String", input: "The quick brown fox jumps over the lazy dog."}, + {name: "Unicode Characters", input: "こんにちは世界"}, + {name: "Special Characters", input: "!@#$%^&*()_+-=[]{}|;':\",./<>?"}, + {name: "Numeric String", input: "1234567890"}, + {name: "Email Address", input: "user@example.com"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + encrypted, err := ec.Encrypt(tc.input) + require.NoError(t, err) + + decrypted, err := ec.Decrypt(encrypted) + require.NoError(t, err) + + assert.Equal(t, tc.input, decrypted) + }) + } +} + +func TestEncrypt_DifferentCiphertexts(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + ec, err := NewFieldEncrypt(key) + require.NoError(t, err) + + plaintext := "same plaintext" + + // Encrypt the same plaintext multiple times + encrypted1, err := ec.Encrypt(plaintext) + require.NoError(t, err) + + encrypted2, err := ec.Encrypt(plaintext) + require.NoError(t, err) + + assert.NotEqual(t, encrypted1, encrypted2, "expected different ciphertexts for same plaintext (random nonce)") + + // Both should decrypt to the same plaintext + decrypted1, err := ec.Decrypt(encrypted1) + require.NoError(t, err) + + decrypted2, err := ec.Decrypt(encrypted2) + require.NoError(t, err) + + assert.Equal(t, plaintext, decrypted1) + assert.Equal(t, plaintext, decrypted2) +} + +func TestDecrypt_InvalidCiphertext(t *testing.T) { + key, err := GenerateKey() + assert.NoError(t, err) + + ec, err := NewFieldEncrypt(key) + assert.NoError(t, err) + + tests := []struct { + name string + ciphertext string + }{ + {name: "invalid base64", ciphertext: "not-valid!!!"}, + {name: "too short", ciphertext: "c2hvcnQ="}, + {name: "corrupted", ciphertext: "YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXo="}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + payload, err := ec.Decrypt(tt.ciphertext) + assert.Error(t, err) + assert.Empty(t, payload) + }) + } +} + +func TestDecrypt_WrongKey(t *testing.T) { + key1, _ := GenerateKey() + key2, _ := GenerateKey() + + ec1, _ := NewFieldEncrypt(key1) + ec2, _ := NewFieldEncrypt(key2) + + plaintext := "secret data" + encrypted, _ := ec1.Encrypt(plaintext) + + // Try to decrypt with wrong key + payload, err := ec2.Decrypt(encrypted) + assert.Error(t, err) + assert.Empty(t, payload) +} diff --git a/util/crypt/legacy.go b/util/crypt/legacy.go new file mode 100644 index 000000000..f84e6964f --- /dev/null +++ b/util/crypt/legacy.go @@ -0,0 +1,71 @@ +package crypt + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "encoding/base64" + "fmt" +) + +// legacyIV is the static IV used by the legacy CBC encryption. +// Deprecated: This is kept only for backward compatibility with existing encrypted data. +var legacyIV = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05} + +// LegacyEncrypt encrypts plaintext using AES-CBC with a static IV. +// Deprecated: Use Encrypt instead. This method is kept only for backward compatibility. +func (f *FieldEncrypt) LegacyEncrypt(plaintext string) string { + padded := pkcs5Padding([]byte(plaintext)) + ciphertext := make([]byte, len(padded)) + cbc := cipher.NewCBCEncrypter(f.block, legacyIV) + cbc.CryptBlocks(ciphertext, padded) + return base64.StdEncoding.EncodeToString(ciphertext) +} + +// LegacyDecrypt decrypts ciphertext that was encrypted using AES-CBC with a static IV. +// Deprecated: This method is kept only for backward compatibility with existing encrypted data. +func (f *FieldEncrypt) LegacyDecrypt(ciphertext string) (string, error) { + data, err := base64.StdEncoding.DecodeString(ciphertext) + if err != nil { + return "", fmt.Errorf("decode ciphertext: %w", err) + } + + cbc := cipher.NewCBCDecrypter(f.block, legacyIV) + cbc.CryptBlocks(data, data) + + plaintext, err := pkcs5UnPadding(data) + if err != nil { + return "", fmt.Errorf("unpad plaintext: %w", err) + } + + return string(plaintext), nil +} + +// pkcs5Padding adds PKCS#5 padding to the input. +func pkcs5Padding(data []byte) []byte { + padding := aes.BlockSize - len(data)%aes.BlockSize + padText := bytes.Repeat([]byte{byte(padding)}, padding) + return append(data, padText...) +} + +// pkcs5UnPadding removes PKCS#5 padding from the input. +func pkcs5UnPadding(data []byte) ([]byte, error) { + length := len(data) + if length == 0 { + return nil, fmt.Errorf("input data is empty") + } + + paddingLen := int(data[length-1]) + if paddingLen == 0 || paddingLen > aes.BlockSize || paddingLen > length { + return nil, fmt.Errorf("invalid padding size") + } + + // Verify that all padding bytes are the same + for i := 0; i < paddingLen; i++ { + if data[length-1-i] != byte(paddingLen) { + return nil, fmt.Errorf("invalid padding") + } + } + + return data[:length-paddingLen], nil +} diff --git a/util/crypt/legacy_test.go b/util/crypt/legacy_test.go new file mode 100644 index 000000000..09b75a71f --- /dev/null +++ b/util/crypt/legacy_test.go @@ -0,0 +1,164 @@ +package crypt + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLegacyEncryptDecrypt(t *testing.T) { + testData := "exampl@netbird.io" + key, err := GenerateKey() + require.NoError(t, err) + + ec, err := NewFieldEncrypt(key) + require.NoError(t, err) + + encrypted := ec.LegacyEncrypt(testData) + assert.NotEmpty(t, encrypted) + + decrypted, err := ec.LegacyDecrypt(encrypted) + require.NoError(t, err) + + assert.Equal(t, testData, decrypted) +} + +func TestLegacyEncryptDecryptVariousInputs(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + ec, err := NewFieldEncrypt(key) + require.NoError(t, err) + + testCases := []struct { + name string + input string + }{ + {name: "Empty String", input: ""}, + {name: "Short String", input: "Hello"}, + {name: "String with Spaces", input: "Hello, World!"}, + {name: "Long String", input: "The quick brown fox jumps over the lazy dog."}, + {name: "Unicode Characters", input: "こんにちは世界"}, + {name: "Special Characters", input: "!@#$%^&*()_+-=[]{}|;':\",./<>?"}, + {name: "Numeric String", input: "1234567890"}, + {name: "Repeated Characters", input: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}, + {name: "Multi-block String", input: "This is a longer string that will span multiple blocks in the encryption algorithm."}, + {name: "Non-ASCII and ASCII Mix", input: "Hello 世界 123"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + encrypted := ec.LegacyEncrypt(tc.input) + assert.NotEmpty(t, encrypted) + + decrypted, err := ec.LegacyDecrypt(encrypted) + require.NoError(t, err) + + assert.Equal(t, tc.input, decrypted) + }) + } +} + +func TestPKCS5UnPadding(t *testing.T) { + tests := []struct { + name string + input []byte + expected []byte + expectError bool + }{ + { + name: "Valid Padding", + input: append([]byte("Hello, World!"), bytes.Repeat([]byte{4}, 4)...), + expected: []byte("Hello, World!"), + }, + { + name: "Empty Input", + input: []byte{}, + expectError: true, + }, + { + name: "Padding Length Zero", + input: append([]byte("Hello, World!"), bytes.Repeat([]byte{0}, 4)...), + expectError: true, + }, + { + name: "Padding Length Exceeds Block Size", + input: append([]byte("Hello, World!"), bytes.Repeat([]byte{17}, 17)...), + expectError: true, + }, + { + name: "Padding Length Exceeds Input Length", + input: []byte{5, 5, 5}, + expectError: true, + }, + { + name: "Invalid Padding Bytes", + input: append([]byte("Hello, World!"), []byte{2, 3, 4, 5}...), + expectError: true, + }, + { + name: "Valid Single Byte Padding", + input: append([]byte("Hello, World!"), byte(1)), + expected: []byte("Hello, World!"), + }, + { + name: "Invalid Mixed Padding Bytes", + input: append([]byte("Hello, World!"), []byte{3, 3, 2}...), + expectError: true, + }, + { + name: "Valid Full Block Padding", + input: append([]byte("Hello, World!"), bytes.Repeat([]byte{16}, 16)...), + expected: []byte("Hello, World!"), + }, + { + name: "Non-Padding Byte at End", + input: append([]byte("Hello, World!"), []byte{4, 4, 4, 5}...), + expectError: true, + }, + { + name: "Valid Padding with Different Text Length", + input: append([]byte("Test"), bytes.Repeat([]byte{12}, 12)...), + expected: []byte("Test"), + }, + { + name: "Padding Length Equal to Input Length", + input: bytes.Repeat([]byte{8}, 8), + expected: []byte{}, + }, + { + name: "Invalid Padding Length Zero (Again)", + input: append([]byte("Test"), byte(0)), + expectError: true, + }, + { + name: "Padding Length Greater Than Input", + input: []byte{10}, + expectError: true, + }, + { + name: "Input Length Not Multiple of Block Size", + input: append([]byte("Invalid Length"), byte(1)), + expected: []byte("Invalid Length"), + }, + { + name: "Valid Padding with Non-ASCII Characters", + input: append([]byte("こんにちは"), bytes.Repeat([]byte{2}, 2)...), + expected: []byte("こんにちは"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := pkcs5UnPadding(tt.input) + if tt.expectError { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +}