mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 08:46:38 +00:00
Compare commits
7 Commits
transparen
...
fix/androi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f01c1eea6a | ||
|
|
a35ecf9aa8 | ||
|
|
1d792f0b53 | ||
|
|
b3178255c0 | ||
|
|
4eed459f27 | ||
|
|
13539543af | ||
|
|
7483fec048 |
@@ -8,6 +8,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
@@ -15,6 +16,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/debug"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
@@ -26,6 +28,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
|
types "github.com/netbirdio/netbird/upload-server/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionListener export internal Listener for mobile
|
// ConnectionListener export internal Listener for mobile
|
||||||
@@ -69,6 +72,8 @@ type Client struct {
|
|||||||
networkChangeListener listener.NetworkChangeListener
|
networkChangeListener listener.NetworkChangeListener
|
||||||
|
|
||||||
connectClient *internal.ConnectClient
|
connectClient *internal.ConnectClient
|
||||||
|
config *profilemanager.Config
|
||||||
|
cacheDir string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient instantiate a new Client
|
// NewClient instantiate a new Client
|
||||||
@@ -93,6 +98,7 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
|||||||
|
|
||||||
cfgFile := platformFiles.ConfigurationFilePath()
|
cfgFile := platformFiles.ConfigurationFilePath()
|
||||||
stateFile := platformFiles.StateFilePath()
|
stateFile := platformFiles.StateFilePath()
|
||||||
|
cacheDir := platformFiles.CacheDir()
|
||||||
|
|
||||||
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
|
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
|
||||||
|
|
||||||
@@ -124,8 +130,10 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
|
c.config = cfg
|
||||||
|
c.cacheDir = cacheDir
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||||
@@ -135,6 +143,7 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
|||||||
|
|
||||||
cfgFile := platformFiles.ConfigurationFilePath()
|
cfgFile := platformFiles.ConfigurationFilePath()
|
||||||
stateFile := platformFiles.StateFilePath()
|
stateFile := platformFiles.StateFilePath()
|
||||||
|
cacheDir := platformFiles.CacheDir()
|
||||||
|
|
||||||
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
|
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
|
||||||
|
|
||||||
@@ -157,8 +166,10 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
|
c.config = cfg
|
||||||
|
c.cacheDir = cacheDir
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
@@ -185,6 +196,74 @@ func (c *Client) RenewTun(fd int) error {
|
|||||||
return e.RenewTun(fd)
|
return e.RenewTun(fd)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DebugBundle generates a debug bundle, uploads it, and returns the upload key.
|
||||||
|
// It works both with and without a running engine.
|
||||||
|
func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (string, error) {
|
||||||
|
cfg := c.config
|
||||||
|
cacheDir := c.cacheDir
|
||||||
|
|
||||||
|
// If the engine hasn't been started, load config from disk
|
||||||
|
if cfg == nil {
|
||||||
|
var err error
|
||||||
|
cfg, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||||
|
ConfigPath: platformFiles.ConfigurationFilePath(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("load config: %w", err)
|
||||||
|
}
|
||||||
|
cacheDir = platformFiles.CacheDir()
|
||||||
|
}
|
||||||
|
|
||||||
|
deps := debug.GeneratorDependencies{
|
||||||
|
InternalConfig: cfg,
|
||||||
|
StatusRecorder: c.recorder,
|
||||||
|
TempDir: cacheDir,
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.connectClient != nil {
|
||||||
|
resp, err := c.connectClient.GetLatestSyncResponse()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("get latest sync response: %v", err)
|
||||||
|
}
|
||||||
|
deps.SyncResponse = resp
|
||||||
|
|
||||||
|
if e := c.connectClient.Engine(); e != nil {
|
||||||
|
if cm := e.GetClientMetrics(); cm != nil {
|
||||||
|
deps.ClientMetrics = cm
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bundleGenerator := debug.NewBundleGenerator(
|
||||||
|
deps,
|
||||||
|
debug.BundleConfig{
|
||||||
|
Anonymize: anonymize,
|
||||||
|
IncludeSystemInfo: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
path, err := bundleGenerator.Generate()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("generate debug bundle: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := os.Remove(path); err != nil {
|
||||||
|
log.Errorf("failed to remove debug bundle file: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("upload debug bundle: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("debug bundle uploaded with key %s", key)
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
// SetTraceLogLevel configure the logger to trace level
|
// SetTraceLogLevel configure the logger to trace level
|
||||||
func (c *Client) SetTraceLogLevel() {
|
func (c *Client) SetTraceLogLevel() {
|
||||||
log.SetLevel(log.TraceLevel)
|
log.SetLevel(log.TraceLevel)
|
||||||
|
|||||||
@@ -7,4 +7,5 @@ package android
|
|||||||
type PlatformFiles interface {
|
type PlatformFiles interface {
|
||||||
ConfigurationFilePath() string
|
ConfigurationFilePath() string
|
||||||
StateFilePath() string
|
StateFilePath() string
|
||||||
|
CacheDir() string
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -56,6 +56,13 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg
|
|||||||
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
|
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Native firewall handles packet filtering, but the userspace WireGuard bind
|
||||||
|
// needs a device filter for DNS interception hooks. Install a minimal
|
||||||
|
// hooks-only filter that passes all traffic through to the kernel firewall.
|
||||||
|
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
|
||||||
|
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return fm, nil
|
return fm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
37
client/firewall/uspfilter/common/hooks.go
Normal file
37
client/firewall/uspfilter/common/hooks.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PacketHook stores a registered hook for a specific IP:port.
|
||||||
|
type PacketHook struct {
|
||||||
|
IP netip.Addr
|
||||||
|
Port uint16
|
||||||
|
Fn func([]byte) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// HookMatches checks if a packet's destination matches the hook and invokes it.
|
||||||
|
func HookMatches(h *PacketHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
|
||||||
|
if h == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if h.IP == dstIP && h.Port == dport {
|
||||||
|
return h.Fn(packetData)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHook atomically stores a hook, handling nil removal.
|
||||||
|
func SetHook(ptr *atomic.Pointer[PacketHook], ip netip.Addr, dPort uint16, hook func([]byte) bool) {
|
||||||
|
if hook == nil {
|
||||||
|
ptr.Store(nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ptr.Store(&PacketHook{
|
||||||
|
IP: ip,
|
||||||
|
Port: dPort,
|
||||||
|
Fn: hook,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -142,15 +142,8 @@ type Manager struct {
|
|||||||
mssClampEnabled bool
|
mssClampEnabled bool
|
||||||
|
|
||||||
// Only one hook per protocol is supported. Outbound direction only.
|
// Only one hook per protocol is supported. Outbound direction only.
|
||||||
udpHookOut atomic.Pointer[packetHook]
|
udpHookOut atomic.Pointer[common.PacketHook]
|
||||||
tcpHookOut atomic.Pointer[packetHook]
|
tcpHookOut atomic.Pointer[common.PacketHook]
|
||||||
}
|
|
||||||
|
|
||||||
// packetHook stores a registered hook for a specific IP:port.
|
|
||||||
type packetHook struct {
|
|
||||||
ip netip.Addr
|
|
||||||
port uint16
|
|
||||||
fn func([]byte) bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// decoder for packages
|
// decoder for packages
|
||||||
@@ -912,21 +905,11 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||||
return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
|
return common.HookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||||
return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
|
return common.HookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
|
||||||
}
|
|
||||||
|
|
||||||
func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
|
|
||||||
if h == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if h.ip == dstIP && h.port == dport {
|
|
||||||
return h.fn(packetData)
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// filterInbound implements filtering logic for incoming packets.
|
// filterInbound implements filtering logic for incoming packets.
|
||||||
@@ -1337,28 +1320,12 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
|
|||||||
|
|
||||||
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
|
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
|
||||||
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
||||||
if hook == nil {
|
common.SetHook(&m.udpHookOut, ip, dPort, hook)
|
||||||
m.udpHookOut.Store(nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
m.udpHookOut.Store(&packetHook{
|
|
||||||
ip: ip,
|
|
||||||
port: dPort,
|
|
||||||
fn: hook,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
|
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
|
||||||
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
||||||
if hook == nil {
|
common.SetHook(&m.tcpHookOut, ip, dPort, hook)
|
||||||
m.tcpHookOut.Store(nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
m.tcpHookOut.Store(&packetHook{
|
|
||||||
ip: ip,
|
|
||||||
port: dPort,
|
|
||||||
fn: hook,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetLogLevel sets the log level for the firewall manager
|
// SetLogLevel sets the log level for the firewall manager
|
||||||
|
|||||||
@@ -202,9 +202,9 @@ func TestSetUDPPacketHook(t *testing.T) {
|
|||||||
|
|
||||||
h := manager.udpHookOut.Load()
|
h := manager.udpHookOut.Load()
|
||||||
require.NotNil(t, h)
|
require.NotNil(t, h)
|
||||||
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
|
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
|
||||||
assert.Equal(t, uint16(8000), h.port)
|
assert.Equal(t, uint16(8000), h.Port)
|
||||||
assert.True(t, h.fn(nil))
|
assert.True(t, h.Fn(nil))
|
||||||
assert.True(t, called)
|
assert.True(t, called)
|
||||||
|
|
||||||
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
|
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
|
||||||
@@ -226,9 +226,9 @@ func TestSetTCPPacketHook(t *testing.T) {
|
|||||||
|
|
||||||
h := manager.tcpHookOut.Load()
|
h := manager.tcpHookOut.Load()
|
||||||
require.NotNil(t, h)
|
require.NotNil(t, h)
|
||||||
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
|
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
|
||||||
assert.Equal(t, uint16(53), h.port)
|
assert.Equal(t, uint16(53), h.Port)
|
||||||
assert.True(t, h.fn(nil))
|
assert.True(t, h.Fn(nil))
|
||||||
assert.True(t, called)
|
assert.True(t, called)
|
||||||
|
|
||||||
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)
|
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)
|
||||||
|
|||||||
90
client/firewall/uspfilter/hooks_filter.go
Normal file
90
client/firewall/uspfilter/hooks_filter.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"net/netip"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ipv4HeaderMinLen = 20
|
||||||
|
ipv4ProtoOffset = 9
|
||||||
|
ipv4FlagsOffset = 6
|
||||||
|
ipv4DstOffset = 16
|
||||||
|
ipProtoUDP = 17
|
||||||
|
ipProtoTCP = 6
|
||||||
|
ipv4FragOffMask = 0x1fff
|
||||||
|
// dstPortOffset is the offset of the destination port within a UDP or TCP header.
|
||||||
|
dstPortOffset = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
// HooksFilter is a minimal packet filter that only handles outbound DNS hooks.
|
||||||
|
// It is installed on the WireGuard interface when the userspace bind is active
|
||||||
|
// but a full firewall filter (Manager) is not needed because a native kernel
|
||||||
|
// firewall (nftables/iptables) handles packet filtering.
|
||||||
|
type HooksFilter struct {
|
||||||
|
udpHook atomic.Pointer[common.PacketHook]
|
||||||
|
tcpHook atomic.Pointer[common.PacketHook]
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ device.PacketFilter = (*HooksFilter)(nil)
|
||||||
|
|
||||||
|
// FilterOutbound checks outbound packets for DNS hook matches.
|
||||||
|
// Only IPv4 packets matching the registered hook IP:port are intercepted.
|
||||||
|
// IPv6 and non-IP packets pass through unconditionally.
|
||||||
|
func (f *HooksFilter) FilterOutbound(packetData []byte, _ int) bool {
|
||||||
|
if len(packetData) < ipv4HeaderMinLen {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only process IPv4 packets, let everything else pass through.
|
||||||
|
if packetData[0]>>4 != 4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
ihl := int(packetData[0]&0x0f) * 4
|
||||||
|
if ihl < ipv4HeaderMinLen || len(packetData) < ihl+4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip non-first fragments: they don't carry L4 headers.
|
||||||
|
flagsAndOffset := binary.BigEndian.Uint16(packetData[ipv4FlagsOffset : ipv4FlagsOffset+2])
|
||||||
|
if flagsAndOffset&ipv4FragOffMask != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
dstIP, ok := netip.AddrFromSlice(packetData[ipv4DstOffset : ipv4DstOffset+4])
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
proto := packetData[ipv4ProtoOffset]
|
||||||
|
dstPort := binary.BigEndian.Uint16(packetData[ihl+dstPortOffset : ihl+dstPortOffset+2])
|
||||||
|
|
||||||
|
switch proto {
|
||||||
|
case ipProtoUDP:
|
||||||
|
return common.HookMatches(f.udpHook.Load(), dstIP, dstPort, packetData)
|
||||||
|
case ipProtoTCP:
|
||||||
|
return common.HookMatches(f.tcpHook.Load(), dstIP, dstPort, packetData)
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterInbound allows all inbound packets (native firewall handles filtering).
|
||||||
|
func (f *HooksFilter) FilterInbound([]byte, int) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUDPPacketHook registers the UDP packet hook.
|
||||||
|
func (f *HooksFilter) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
|
||||||
|
common.SetHook(&f.udpHook, ip, dPort, hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTCPPacketHook registers the TCP packet hook.
|
||||||
|
func (f *HooksFilter) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
|
||||||
|
common.SetHook(&f.tcpHook, ip, dPort, hook)
|
||||||
|
}
|
||||||
@@ -94,6 +94,7 @@ func (c *ConnectClient) RunOnAndroid(
|
|||||||
dnsAddresses []netip.AddrPort,
|
dnsAddresses []netip.AddrPort,
|
||||||
dnsReadyListener dns.ReadyListener,
|
dnsReadyListener dns.ReadyListener,
|
||||||
stateFilePath string,
|
stateFilePath string,
|
||||||
|
cacheDir string,
|
||||||
) error {
|
) error {
|
||||||
// in case of non Android os these variables will be nil
|
// in case of non Android os these variables will be nil
|
||||||
mobileDependency := MobileDependency{
|
mobileDependency := MobileDependency{
|
||||||
@@ -103,6 +104,7 @@ func (c *ConnectClient) RunOnAndroid(
|
|||||||
HostDNSAddresses: dnsAddresses,
|
HostDNSAddresses: dnsAddresses,
|
||||||
DnsReadyListener: dnsReadyListener,
|
DnsReadyListener: dnsReadyListener,
|
||||||
StateFilePath: stateFilePath,
|
StateFilePath: stateFilePath,
|
||||||
|
TempDir: cacheDir,
|
||||||
}
|
}
|
||||||
return c.run(mobileDependency, nil, "")
|
return c.run(mobileDependency, nil, "")
|
||||||
}
|
}
|
||||||
@@ -338,6 +340,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
log.Error(err)
|
log.Error(err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
|
engineConfig.TempDir = mobileDependency.TempDir
|
||||||
|
|
||||||
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
|
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
|
||||||
c.statusRecorder.SetRelayMgr(relayManager)
|
c.statusRecorder.SetRelayMgr(relayManager)
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/pprof"
|
"runtime/pprof"
|
||||||
"slices"
|
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -31,7 +30,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/util"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const readmeContent = `Netbird debug bundle
|
const readmeContent = `Netbird debug bundle
|
||||||
@@ -234,6 +232,7 @@ type BundleGenerator struct {
|
|||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
syncResponse *mgmProto.SyncResponse
|
syncResponse *mgmProto.SyncResponse
|
||||||
logPath string
|
logPath string
|
||||||
|
tempDir string
|
||||||
cpuProfile []byte
|
cpuProfile []byte
|
||||||
refreshStatus func() // Optional callback to refresh status before bundle generation
|
refreshStatus func() // Optional callback to refresh status before bundle generation
|
||||||
clientMetrics MetricsExporter
|
clientMetrics MetricsExporter
|
||||||
@@ -256,6 +255,7 @@ type GeneratorDependencies struct {
|
|||||||
StatusRecorder *peer.Status
|
StatusRecorder *peer.Status
|
||||||
SyncResponse *mgmProto.SyncResponse
|
SyncResponse *mgmProto.SyncResponse
|
||||||
LogPath string
|
LogPath string
|
||||||
|
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
|
||||||
CPUProfile []byte
|
CPUProfile []byte
|
||||||
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
||||||
ClientMetrics MetricsExporter
|
ClientMetrics MetricsExporter
|
||||||
@@ -275,6 +275,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
|||||||
statusRecorder: deps.StatusRecorder,
|
statusRecorder: deps.StatusRecorder,
|
||||||
syncResponse: deps.SyncResponse,
|
syncResponse: deps.SyncResponse,
|
||||||
logPath: deps.LogPath,
|
logPath: deps.LogPath,
|
||||||
|
tempDir: deps.TempDir,
|
||||||
cpuProfile: deps.CPUProfile,
|
cpuProfile: deps.CPUProfile,
|
||||||
refreshStatus: deps.RefreshStatus,
|
refreshStatus: deps.RefreshStatus,
|
||||||
clientMetrics: deps.ClientMetrics,
|
clientMetrics: deps.ClientMetrics,
|
||||||
@@ -287,7 +288,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
|||||||
|
|
||||||
// Generate creates a debug bundle and returns the location.
|
// Generate creates a debug bundle and returns the location.
|
||||||
func (g *BundleGenerator) Generate() (resp string, err error) {
|
func (g *BundleGenerator) Generate() (resp string, err error) {
|
||||||
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
|
bundlePath, err := os.CreateTemp(g.tempDir, "netbird.debug.*.zip")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("create zip file: %w", err)
|
return "", fmt.Errorf("create zip file: %w", err)
|
||||||
}
|
}
|
||||||
@@ -373,15 +374,8 @@ func (g *BundleGenerator) createArchive() error {
|
|||||||
log.Errorf("failed to add wg show output: %v", err)
|
log.Errorf("failed to add wg show output: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
|
if err := g.addPlatformLog(); err != nil {
|
||||||
if err := g.addLogfile(); err != nil {
|
log.Errorf("failed to add logs to debug bundle: %v", err)
|
||||||
log.Errorf("failed to add log file to debug bundle: %v", err)
|
|
||||||
if err := g.trySystemdLogFallback(); err != nil {
|
|
||||||
log.Errorf("failed to add systemd logs as fallback: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if err := g.trySystemdLogFallback(); err != nil {
|
|
||||||
log.Errorf("failed to add systemd logs: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := g.addUpdateLogs(); err != nil {
|
if err := g.addUpdateLogs(); err != nil {
|
||||||
|
|||||||
34
client/internal/debug/debug_android.go
Normal file
34
client/internal/debug/debug_android.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package debug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addPlatformLog() error {
|
||||||
|
cmd := exec.Command("/system/bin/logcat", "-d")
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("run logcat: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var logReader *strings.Reader
|
||||||
|
if g.anonymize {
|
||||||
|
anonymized := g.anonymizer.AnonymizeString(string(out))
|
||||||
|
logReader = strings.NewReader(anonymized)
|
||||||
|
} else {
|
||||||
|
logReader = strings.NewReader(string(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.addFileToZip(logReader, "logcat.txt"); err != nil {
|
||||||
|
return fmt.Errorf("add logcat to zip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("added logcat output to debug bundle (%d bytes)", len(out))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
25
client/internal/debug/debug_nonandroid.go
Normal file
25
client/internal/debug/debug_nonandroid.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package debug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addPlatformLog() error {
|
||||||
|
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
|
||||||
|
if err := g.addLogfile(); err != nil {
|
||||||
|
log.Errorf("failed to add log file to debug bundle: %v", err)
|
||||||
|
if err := g.trySystemdLogFallback(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if err := g.trySystemdLogFallback(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -140,6 +140,7 @@ type EngineConfig struct {
|
|||||||
ProfileConfig *profilemanager.Config
|
ProfileConfig *profilemanager.Config
|
||||||
|
|
||||||
LogPath string
|
LogPath string
|
||||||
|
TempDir string
|
||||||
}
|
}
|
||||||
|
|
||||||
// EngineServices holds the external service dependencies required by the Engine.
|
// EngineServices holds the external service dependencies required by the Engine.
|
||||||
@@ -1095,6 +1096,7 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
|
|||||||
StatusRecorder: e.statusRecorder,
|
StatusRecorder: e.statusRecorder,
|
||||||
SyncResponse: syncResponse,
|
SyncResponse: syncResponse,
|
||||||
LogPath: e.config.LogPath,
|
LogPath: e.config.LogPath,
|
||||||
|
TempDir: e.config.TempDir,
|
||||||
ClientMetrics: e.clientMetrics,
|
ClientMetrics: e.clientMetrics,
|
||||||
RefreshStatus: func() {
|
RefreshStatus: func() {
|
||||||
e.RunHealthProbes(true)
|
e.RunHealthProbes(true)
|
||||||
|
|||||||
@@ -22,4 +22,8 @@ type MobileDependency struct {
|
|||||||
DnsManager dns.IosDnsManager
|
DnsManager dns.IosDnsManager
|
||||||
FileDescriptor int32
|
FileDescriptor int32
|
||||||
StateFilePath string
|
StateFilePath string
|
||||||
|
|
||||||
|
// TempDir is a writable directory for temporary files (e.g., debug bundle zip).
|
||||||
|
// On Android, this should be set to the app's cache directory.
|
||||||
|
TempDir string
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -168,6 +168,7 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
|
|||||||
NetworkType: route.IPv4Network,
|
NetworkType: route.IPv4Network,
|
||||||
}
|
}
|
||||||
cr = append(cr, fakeIPRoute)
|
cr = append(cr, fakeIPRoute)
|
||||||
|
m.notifier.SetFakeIPRoute(fakeIPRoute)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.notifier.SetInitialClientRoutes(cr, routesForComparison)
|
m.notifier.SetInitialClientRoutes(cr, routesForComparison)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
type Notifier struct {
|
type Notifier struct {
|
||||||
initialRoutes []*route.Route
|
initialRoutes []*route.Route
|
||||||
currentRoutes []*route.Route
|
currentRoutes []*route.Route
|
||||||
|
fakeIPRoute *route.Route
|
||||||
|
|
||||||
listener listener.NetworkChangeListener
|
listener listener.NetworkChangeListener
|
||||||
listenerMux sync.Mutex
|
listenerMux sync.Mutex
|
||||||
@@ -31,13 +32,17 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
|||||||
n.listener = listener
|
n.listener = listener
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetInitialClientRoutes stores the full initial route set (including fake IP blocks)
|
// SetInitialClientRoutes stores the initial route sets for TUN configuration.
|
||||||
// and a separate comparison set (without fake IP blocks) for diff detection.
|
|
||||||
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
|
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
|
||||||
n.initialRoutes = filterStatic(initialRoutes)
|
n.initialRoutes = filterStatic(initialRoutes)
|
||||||
n.currentRoutes = filterStatic(routesForComparison)
|
n.currentRoutes = filterStatic(routesForComparison)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFakeIPRoute stores the fake IP route to be included in every TUN rebuild.
|
||||||
|
func (n *Notifier) SetFakeIPRoute(r *route.Route) {
|
||||||
|
n.fakeIPRoute = r
|
||||||
|
}
|
||||||
|
|
||||||
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||||
var newRoutes []*route.Route
|
var newRoutes []*route.Route
|
||||||
for _, routes := range idMap {
|
for _, routes := range idMap {
|
||||||
@@ -69,7 +74,9 @@ func (n *Notifier) notify() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
allRoutes := slices.Clone(n.currentRoutes)
|
allRoutes := slices.Clone(n.currentRoutes)
|
||||||
allRoutes = append(allRoutes, n.extraInitialRoutes()...)
|
if n.fakeIPRoute != nil {
|
||||||
|
allRoutes = append(allRoutes, n.fakeIPRoute)
|
||||||
|
}
|
||||||
|
|
||||||
routeStrings := n.routesToStrings(allRoutes)
|
routeStrings := n.routesToStrings(allRoutes)
|
||||||
sort.Strings(routeStrings)
|
sort.Strings(routeStrings)
|
||||||
@@ -78,23 +85,6 @@ func (n *Notifier) notify() {
|
|||||||
}(n.listener)
|
}(n.listener)
|
||||||
}
|
}
|
||||||
|
|
||||||
// extraInitialRoutes returns initialRoutes whose network prefix is absent
|
|
||||||
// from currentRoutes (e.g. the fake IP block added at setup time).
|
|
||||||
func (n *Notifier) extraInitialRoutes() []*route.Route {
|
|
||||||
currentNets := make(map[netip.Prefix]struct{}, len(n.currentRoutes))
|
|
||||||
for _, r := range n.currentRoutes {
|
|
||||||
currentNets[r.Network] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
var extra []*route.Route
|
|
||||||
for _, r := range n.initialRoutes {
|
|
||||||
if _, ok := currentNets[r.Network]; !ok {
|
|
||||||
extra = append(extra, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return extra
|
|
||||||
}
|
|
||||||
|
|
||||||
func filterStatic(routes []*route.Route) []*route.Route {
|
func filterStatic(routes []*route.Route) []*route.Route {
|
||||||
out := make([]*route.Route, 0, len(routes))
|
out := make([]*route.Route, 0, len(routes))
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
|
|||||||
@@ -34,6 +34,10 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
|
|||||||
// iOS doesn't care about initial routes
|
// iOS doesn't care about initial routes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) SetFakeIPRoute(*route.Route) {
|
||||||
|
// Not used on iOS
|
||||||
|
}
|
||||||
|
|
||||||
func (n *Notifier) OnNewRoutes(route.HAMap) {
|
func (n *Notifier) OnNewRoutes(route.HAMap) {
|
||||||
// Not used on iOS
|
// Not used on iOS
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,6 +23,10 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
|
|||||||
// Not used on non-mobile platforms
|
// Not used on non-mobile platforms
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) SetFakeIPRoute(*route.Route) {
|
||||||
|
// Not used on non-mobile platforms
|
||||||
|
}
|
||||||
|
|
||||||
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||||
// Not used on non-mobile platforms
|
// Not used on non-mobile platforms
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/connectivity"
|
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
@@ -26,11 +25,22 @@ import (
|
|||||||
"github.com/netbirdio/netbird/util/wsproxy"
|
"github.com/netbirdio/netbird/util/wsproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ErrClientClosed = errors.New("client is closed")
|
||||||
|
|
||||||
|
// minHealthyDuration is the minimum time a stream must survive before a failure
|
||||||
|
// resets the backoff timer. Streams that fail faster are considered unhealthy and
|
||||||
|
// should not reset backoff, so that MaxElapsedTime can eventually stop retries.
|
||||||
|
const minHealthyDuration = 5 * time.Second
|
||||||
|
|
||||||
type GRPCClient struct {
|
type GRPCClient struct {
|
||||||
realClient proto.FlowServiceClient
|
realClient proto.FlowServiceClient
|
||||||
clientConn *grpc.ClientConn
|
clientConn *grpc.ClientConn
|
||||||
stream proto.FlowService_EventsClient
|
stream proto.FlowService_EventsClient
|
||||||
streamMu sync.Mutex
|
target string
|
||||||
|
opts []grpc.DialOption
|
||||||
|
closed bool // prevent creating conn in the middle of the Close
|
||||||
|
receiving bool // prevent concurrent Receive calls
|
||||||
|
mu sync.Mutex // protects clientConn, realClient, stream, closed, and receiving
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCClient, error) {
|
func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCClient, error) {
|
||||||
@@ -65,7 +75,8 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
|
|||||||
grpc.WithDefaultServiceConfig(`{"healthCheckConfig": {"serviceName": ""}}`),
|
grpc.WithDefaultServiceConfig(`{"healthCheckConfig": {"serviceName": ""}}`),
|
||||||
)
|
)
|
||||||
|
|
||||||
conn, err := grpc.NewClient(fmt.Sprintf("%s:%s", parsedURL.Hostname(), parsedURL.Port()), opts...)
|
target := parsedURL.Host
|
||||||
|
conn, err := grpc.NewClient(target, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("creating new grpc client: %w", err)
|
return nil, fmt.Errorf("creating new grpc client: %w", err)
|
||||||
}
|
}
|
||||||
@@ -73,30 +84,73 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
|
|||||||
return &GRPCClient{
|
return &GRPCClient{
|
||||||
realClient: proto.NewFlowServiceClient(conn),
|
realClient: proto.NewFlowServiceClient(conn),
|
||||||
clientConn: conn,
|
clientConn: conn,
|
||||||
|
target: target,
|
||||||
|
opts: opts,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *GRPCClient) Close() error {
|
func (c *GRPCClient) Close() error {
|
||||||
c.streamMu.Lock()
|
c.mu.Lock()
|
||||||
defer c.streamMu.Unlock()
|
c.closed = true
|
||||||
|
|
||||||
c.stream = nil
|
c.stream = nil
|
||||||
if err := c.clientConn.Close(); err != nil && !errors.Is(err, context.Canceled) {
|
conn := c.clientConn
|
||||||
|
c.clientConn = nil
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
if conn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.Close(); err != nil && !errors.Is(err, context.Canceled) {
|
||||||
return fmt.Errorf("close client connection: %w", err)
|
return fmt.Errorf("close client connection: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *GRPCClient) Send(event *proto.FlowEvent) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
stream := c.stream
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
if stream == nil {
|
||||||
|
return errors.New("stream not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := stream.Send(event); err != nil {
|
||||||
|
return fmt.Errorf("send flow event: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHandler func(msg *proto.FlowEventAck) error) error {
|
func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHandler func(msg *proto.FlowEventAck) error) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.receiving {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return errors.New("concurrent Receive calls are not supported")
|
||||||
|
}
|
||||||
|
c.receiving = true
|
||||||
|
c.mu.Unlock()
|
||||||
|
defer func() {
|
||||||
|
c.mu.Lock()
|
||||||
|
c.receiving = false
|
||||||
|
c.mu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
backOff := defaultBackoff(ctx, interval)
|
backOff := defaultBackoff(ctx, interval)
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
if err := c.establishStreamAndReceive(ctx, msgHandler); err != nil {
|
stream, err := c.establishStream(ctx)
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
|
if err != nil {
|
||||||
return fmt.Errorf("receive: %w: %w", err, context.Canceled)
|
log.Errorf("failed to establish flow stream, retrying: %v", err)
|
||||||
|
return c.handleRetryableError(err, time.Time{}, backOff)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
streamStart := time.Now()
|
||||||
|
|
||||||
|
if err := c.receive(stream, msgHandler); err != nil {
|
||||||
log.Errorf("receive failed: %v", err)
|
log.Errorf("receive failed: %v", err)
|
||||||
return fmt.Errorf("receive: %w", err)
|
return c.handleRetryableError(err, streamStart, backOff)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -108,37 +162,106 @@ func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHan
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *GRPCClient) establishStreamAndReceive(ctx context.Context, msgHandler func(msg *proto.FlowEventAck) error) error {
|
// handleRetryableError resets the backoff timer if the stream was healthy long
|
||||||
if c.clientConn.GetState() == connectivity.Shutdown {
|
// enough and recreates the underlying ClientConn so that gRPC's internal
|
||||||
return errors.New("connection to flow receiver has been shut down")
|
// subchannel backoff does not accumulate and compete with our own retry timer.
|
||||||
|
// A zero streamStart means the stream was never established.
|
||||||
|
func (c *GRPCClient) handleRetryableError(err error, streamStart time.Time, backOff backoff.BackOff) error {
|
||||||
|
if isContextDone(err) {
|
||||||
|
return backoff.Permanent(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
stream, err := c.realClient.Events(ctx, grpc.WaitForReady(true))
|
var permErr *backoff.PermanentError
|
||||||
if err != nil {
|
if errors.As(err, &permErr) {
|
||||||
return fmt.Errorf("create event stream: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = stream.Send(&proto.FlowEvent{IsInitiator: true})
|
// Reset the backoff so the next retry starts with a short delay instead of
|
||||||
|
// continuing the already-elapsed timer. Only do this if the stream was healthy
|
||||||
|
// long enough; short-lived connect/drop cycles must not defeat MaxElapsedTime.
|
||||||
|
if !streamStart.IsZero() && time.Since(streamStart) >= minHealthyDuration {
|
||||||
|
backOff.Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
if recreateErr := c.recreateConnection(); recreateErr != nil {
|
||||||
|
log.Errorf("recreate connection: %v", recreateErr)
|
||||||
|
return recreateErr
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("connection recreated, retrying stream")
|
||||||
|
return fmt.Errorf("retrying after error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *GRPCClient) recreateConnection() error {
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.closed {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return backoff.Permanent(ErrClientClosed)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := grpc.NewClient(c.target, c.opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Infof("failed to send initiator message to flow receiver but will attempt to continue. Error: %s", err)
|
c.mu.Unlock()
|
||||||
|
return fmt.Errorf("create new connection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
old := c.clientConn
|
||||||
|
c.clientConn = conn
|
||||||
|
c.realClient = proto.NewFlowServiceClient(conn)
|
||||||
|
c.stream = nil
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
_ = old.Close()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *GRPCClient) establishStream(ctx context.Context) (proto.FlowService_EventsClient, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.closed {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return nil, backoff.Permanent(ErrClientClosed)
|
||||||
|
}
|
||||||
|
cl := c.realClient
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
// open stream outside the lock — blocking operation
|
||||||
|
stream, err := cl.Events(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create event stream: %w", err)
|
||||||
|
}
|
||||||
|
streamReady := false
|
||||||
|
defer func() {
|
||||||
|
if !streamReady {
|
||||||
|
_ = stream.CloseSend()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err = stream.Send(&proto.FlowEvent{IsInitiator: true}); err != nil {
|
||||||
|
return nil, fmt.Errorf("send initiator: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = checkHeader(stream); err != nil {
|
if err = checkHeader(stream); err != nil {
|
||||||
return fmt.Errorf("check header: %w", err)
|
return nil, fmt.Errorf("check header: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.streamMu.Lock()
|
c.mu.Lock()
|
||||||
|
if c.closed {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return nil, backoff.Permanent(ErrClientClosed)
|
||||||
|
}
|
||||||
c.stream = stream
|
c.stream = stream
|
||||||
c.streamMu.Unlock()
|
c.mu.Unlock()
|
||||||
|
streamReady = true
|
||||||
|
|
||||||
return c.receive(stream, msgHandler)
|
return stream, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *GRPCClient) receive(stream proto.FlowService_EventsClient, msgHandler func(msg *proto.FlowEventAck) error) error {
|
func (c *GRPCClient) receive(stream proto.FlowService_EventsClient, msgHandler func(msg *proto.FlowEventAck) error) error {
|
||||||
for {
|
for {
|
||||||
msg, err := stream.Recv()
|
msg, err := stream.Recv()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("receive from stream: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if msg.IsInitiator {
|
if msg.IsInitiator {
|
||||||
@@ -169,7 +292,7 @@ func checkHeader(stream proto.FlowService_EventsClient) error {
|
|||||||
func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff {
|
func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff {
|
||||||
return backoff.WithContext(&backoff.ExponentialBackOff{
|
return backoff.WithContext(&backoff.ExponentialBackOff{
|
||||||
InitialInterval: 800 * time.Millisecond,
|
InitialInterval: 800 * time.Millisecond,
|
||||||
RandomizationFactor: 1,
|
RandomizationFactor: 0.5,
|
||||||
Multiplier: 1.7,
|
Multiplier: 1.7,
|
||||||
MaxInterval: interval / 2,
|
MaxInterval: interval / 2,
|
||||||
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
|
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
|
||||||
@@ -178,18 +301,12 @@ func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff
|
|||||||
}, ctx)
|
}, ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *GRPCClient) Send(event *proto.FlowEvent) error {
|
func isContextDone(err error) bool {
|
||||||
c.streamMu.Lock()
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
stream := c.stream
|
return true
|
||||||
c.streamMu.Unlock()
|
|
||||||
|
|
||||||
if stream == nil {
|
|
||||||
return errors.New("stream not initialized")
|
|
||||||
}
|
}
|
||||||
|
if s, ok := status.FromError(err); ok {
|
||||||
if err := stream.Send(event); err != nil {
|
return s.Code() == codes.Canceled || s.Code() == codes.DeadlineExceeded
|
||||||
return fmt.Errorf("send flow event: %w", err)
|
|
||||||
}
|
}
|
||||||
|
return false
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,8 +2,11 @@ package client_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -11,6 +14,8 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
flow "github.com/netbirdio/netbird/flow/client"
|
flow "github.com/netbirdio/netbird/flow/client"
|
||||||
"github.com/netbirdio/netbird/flow/proto"
|
"github.com/netbirdio/netbird/flow/proto"
|
||||||
@@ -22,17 +27,85 @@ type testServer struct {
|
|||||||
acks chan *proto.FlowEventAck
|
acks chan *proto.FlowEventAck
|
||||||
grpcSrv *grpc.Server
|
grpcSrv *grpc.Server
|
||||||
addr string
|
addr string
|
||||||
|
listener *connTrackListener
|
||||||
|
closeStream chan struct{} // signal server to close the stream
|
||||||
|
handlerDone chan struct{} // signaled each time Events() exits
|
||||||
|
handlerStarted chan struct{} // signaled each time Events() begins
|
||||||
|
}
|
||||||
|
|
||||||
|
// connTrackListener wraps a net.Listener to track accepted connections
|
||||||
|
// so tests can forcefully close them to simulate PROTOCOL_ERROR/RST_STREAM.
|
||||||
|
type connTrackListener struct {
|
||||||
|
net.Listener
|
||||||
|
mu sync.Mutex
|
||||||
|
conns []net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *connTrackListener) Accept() (net.Conn, error) {
|
||||||
|
c, err := l.Listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
l.mu.Lock()
|
||||||
|
l.conns = append(l.conns, c)
|
||||||
|
l.mu.Unlock()
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendRSTStream writes a raw HTTP/2 RST_STREAM frame with PROTOCOL_ERROR
|
||||||
|
// (error code 0x1) on every tracked connection. This produces the exact error:
|
||||||
|
//
|
||||||
|
// rpc error: code = Internal desc = stream terminated by RST_STREAM with error code: PROTOCOL_ERROR
|
||||||
|
//
|
||||||
|
// HTTP/2 RST_STREAM frame format (9-byte header + 4-byte payload):
|
||||||
|
//
|
||||||
|
// Length (3 bytes): 0x000004
|
||||||
|
// Type (1 byte): 0x03 (RST_STREAM)
|
||||||
|
// Flags (1 byte): 0x00
|
||||||
|
// Stream ID (4 bytes): target stream (must have bit 31 clear)
|
||||||
|
// Error Code (4 bytes): 0x00000001 (PROTOCOL_ERROR)
|
||||||
|
func (l *connTrackListener) connCount() int {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
return len(l.conns)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *connTrackListener) sendRSTStream(streamID uint32) {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
|
||||||
|
frame := make([]byte, 13) // 9-byte header + 4-byte payload
|
||||||
|
// Length = 4 (3 bytes, big-endian)
|
||||||
|
frame[0], frame[1], frame[2] = 0, 0, 4
|
||||||
|
// Type = RST_STREAM (0x03)
|
||||||
|
frame[3] = 0x03
|
||||||
|
// Flags = 0
|
||||||
|
frame[4] = 0x00
|
||||||
|
// Stream ID (4 bytes, big-endian, bit 31 reserved = 0)
|
||||||
|
binary.BigEndian.PutUint32(frame[5:9], streamID)
|
||||||
|
// Error Code = PROTOCOL_ERROR (0x1)
|
||||||
|
binary.BigEndian.PutUint32(frame[9:13], 0x1)
|
||||||
|
|
||||||
|
for _, c := range l.conns {
|
||||||
|
_, _ = c.Write(frame)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestServer(t *testing.T) *testServer {
|
func newTestServer(t *testing.T) *testServer {
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
rawListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
listener := &connTrackListener{Listener: rawListener}
|
||||||
|
|
||||||
s := &testServer{
|
s := &testServer{
|
||||||
events: make(chan *proto.FlowEvent, 100),
|
events: make(chan *proto.FlowEvent, 100),
|
||||||
acks: make(chan *proto.FlowEventAck, 100),
|
acks: make(chan *proto.FlowEventAck, 100),
|
||||||
grpcSrv: grpc.NewServer(),
|
grpcSrv: grpc.NewServer(),
|
||||||
addr: listener.Addr().String(),
|
addr: rawListener.Addr().String(),
|
||||||
|
listener: listener,
|
||||||
|
closeStream: make(chan struct{}, 1),
|
||||||
|
handlerDone: make(chan struct{}, 10),
|
||||||
|
handlerStarted: make(chan struct{}, 10),
|
||||||
}
|
}
|
||||||
|
|
||||||
proto.RegisterFlowServiceServer(s.grpcSrv, s)
|
proto.RegisterFlowServiceServer(s.grpcSrv, s)
|
||||||
@@ -51,11 +124,23 @@ func newTestServer(t *testing.T) *testServer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
|
func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
|
||||||
|
defer func() {
|
||||||
|
select {
|
||||||
|
case s.handlerDone <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
err := stream.Send(&proto.FlowEventAck{IsInitiator: true})
|
err := stream.Send(&proto.FlowEventAck{IsInitiator: true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case s.handlerStarted <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(stream.Context())
|
ctx, cancel := context.WithCancel(stream.Context())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@@ -91,6 +176,8 @@ func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
|
|||||||
if err := stream.Send(ack); err != nil {
|
if err := stream.Send(ack); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
case <-s.closeStream:
|
||||||
|
return status.Errorf(codes.Internal, "server closing stream")
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
@@ -110,16 +197,13 @@ func TestReceive(t *testing.T) {
|
|||||||
assert.NoError(t, err, "failed to close flow")
|
assert.NoError(t, err, "failed to close flow")
|
||||||
})
|
})
|
||||||
|
|
||||||
receivedAcks := make(map[string]bool)
|
var ackCount atomic.Int32
|
||||||
receiveDone := make(chan struct{})
|
receiveDone := make(chan struct{})
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||||
if !msg.IsInitiator && len(msg.EventId) > 0 {
|
if !msg.IsInitiator && len(msg.EventId) > 0 {
|
||||||
id := string(msg.EventId)
|
if ackCount.Add(1) >= 3 {
|
||||||
receivedAcks[id] = true
|
|
||||||
|
|
||||||
if len(receivedAcks) >= 3 {
|
|
||||||
close(receiveDone)
|
close(receiveDone)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -130,7 +214,11 @@ func TestReceive(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
select {
|
||||||
|
case <-server.handlerStarted:
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for stream to be established")
|
||||||
|
}
|
||||||
|
|
||||||
for i := 0; i < 3; i++ {
|
for i := 0; i < 3; i++ {
|
||||||
eventID := uuid.New().String()
|
eventID := uuid.New().String()
|
||||||
@@ -153,7 +241,7 @@ func TestReceive(t *testing.T) {
|
|||||||
t.Fatal("timeout waiting for acks to be processed")
|
t.Fatal("timeout waiting for acks to be processed")
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, 3, len(receivedAcks))
|
assert.Equal(t, int32(3), ackCount.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReceive_ContextCancellation(t *testing.T) {
|
func TestReceive_ContextCancellation(t *testing.T) {
|
||||||
@@ -254,3 +342,195 @@ func TestSend(t *testing.T) {
|
|||||||
t.Fatal("timeout waiting for ack to be received by flow")
|
t.Fatal("timeout waiting for ack to be received by flow")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewClient_PermanentClose(t *testing.T) {
|
||||||
|
server := newTestServer(t)
|
||||||
|
|
||||||
|
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = client.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
require.ErrorIs(t, err, flow.ErrClientClosed)
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("Receive did not return after Close — stuck in retry loop")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewClient_CloseVerify(t *testing.T) {
|
||||||
|
server := newTestServer(t)
|
||||||
|
|
||||||
|
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
|
||||||
|
closeDone := make(chan struct{}, 1)
|
||||||
|
go func() {
|
||||||
|
_ = client.Close()
|
||||||
|
closeDone <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
require.Error(t, err)
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("Receive did not return after Close — stuck in retry loop")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-closeDone:
|
||||||
|
return
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("Close did not return — blocked in retry loop")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClose_WhileReceiving(t *testing.T) {
|
||||||
|
server := newTestServer(t)
|
||||||
|
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx := context.Background() // no timeout — intentional
|
||||||
|
receiveDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
_ = client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
close(receiveDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for the server-side handler to confirm the stream is established.
|
||||||
|
select {
|
||||||
|
case <-server.handlerStarted:
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for stream to be established")
|
||||||
|
}
|
||||||
|
|
||||||
|
closeDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
_ = client.Close()
|
||||||
|
close(closeDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-closeDone:
|
||||||
|
// Close returned — good
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("Close blocked forever — Receive stuck in retry loop")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-receiveDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("Receive did not exit after Close")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
|
||||||
|
server := newTestServer(t)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
|
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := client.Close()
|
||||||
|
assert.NoError(t, err, "failed to close flow")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Track acks received before and after server-side stream close
|
||||||
|
var ackCount atomic.Int32
|
||||||
|
receivedFirst := make(chan struct{})
|
||||||
|
receivedAfterReconnect := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||||
|
if msg.IsInitiator || len(msg.EventId) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
n := ackCount.Add(1)
|
||||||
|
if n == 1 {
|
||||||
|
close(receivedFirst)
|
||||||
|
}
|
||||||
|
if n == 2 {
|
||||||
|
close(receivedAfterReconnect)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil && !errors.Is(err, context.Canceled) {
|
||||||
|
t.Logf("receive error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for stream to be established, then send first ack
|
||||||
|
select {
|
||||||
|
case <-server.handlerStarted:
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for stream to be established")
|
||||||
|
}
|
||||||
|
server.acks <- &proto.FlowEventAck{EventId: []byte("before-close")}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-receivedFirst:
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for first ack")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Snapshot connection count before injecting the fault.
|
||||||
|
connsBefore := server.listener.connCount()
|
||||||
|
|
||||||
|
// Send a raw HTTP/2 RST_STREAM frame with PROTOCOL_ERROR on the TCP connection.
|
||||||
|
// gRPC multiplexes streams on stream IDs 1, 3, 5, ... (odd, client-initiated).
|
||||||
|
// Stream ID 1 is the client's first stream (our Events bidi stream).
|
||||||
|
// This produces the exact error the client sees in production:
|
||||||
|
// "stream terminated by RST_STREAM with error code: PROTOCOL_ERROR"
|
||||||
|
server.listener.sendRSTStream(1)
|
||||||
|
|
||||||
|
// Wait for the old Events() handler to fully exit so it can no longer
|
||||||
|
// drain s.acks and drop our injected ack on a broken stream.
|
||||||
|
select {
|
||||||
|
case <-server.handlerDone:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("old Events() handler did not exit after RST_STREAM")
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
return server.listener.connCount() > connsBefore
|
||||||
|
}, 5*time.Second, 50*time.Millisecond, "client did not open a new TCP connection after RST_STREAM")
|
||||||
|
|
||||||
|
server.acks <- &proto.FlowEventAck{EventId: []byte("after-close")}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-receivedAfterReconnect:
|
||||||
|
// Client successfully reconnected and received ack after server-side stream close
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for ack after server-side stream close — client did not reconnect")
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.GreaterOrEqual(t, int(ackCount.Load()), 2, "should have received acks before and after stream close")
|
||||||
|
assert.GreaterOrEqual(t, server.listener.connCount(), 2, "client should have created at least 2 TCP connections (original + reconnect)")
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user