mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
Compare commits
3 Commits
transparen
...
add-slack-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
91a6577182 | ||
|
|
13539543af | ||
|
|
7483fec048 |
@@ -168,6 +168,7 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
|
||||
NetworkType: route.IPv4Network,
|
||||
}
|
||||
cr = append(cr, fakeIPRoute)
|
||||
m.notifier.SetFakeIPRoute(fakeIPRoute)
|
||||
}
|
||||
|
||||
m.notifier.SetInitialClientRoutes(cr, routesForComparison)
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
type Notifier struct {
|
||||
initialRoutes []*route.Route
|
||||
currentRoutes []*route.Route
|
||||
fakeIPRoute *route.Route
|
||||
|
||||
listener listener.NetworkChangeListener
|
||||
listenerMux sync.Mutex
|
||||
@@ -31,13 +32,17 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||
n.listener = listener
|
||||
}
|
||||
|
||||
// SetInitialClientRoutes stores the full initial route set (including fake IP blocks)
|
||||
// and a separate comparison set (without fake IP blocks) for diff detection.
|
||||
// SetInitialClientRoutes stores the initial route sets for TUN configuration.
|
||||
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
|
||||
n.initialRoutes = filterStatic(initialRoutes)
|
||||
n.currentRoutes = filterStatic(routesForComparison)
|
||||
}
|
||||
|
||||
// SetFakeIPRoute stores the fake IP route to be included in every TUN rebuild.
|
||||
func (n *Notifier) SetFakeIPRoute(r *route.Route) {
|
||||
n.fakeIPRoute = r
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||
var newRoutes []*route.Route
|
||||
for _, routes := range idMap {
|
||||
@@ -69,7 +74,9 @@ func (n *Notifier) notify() {
|
||||
}
|
||||
|
||||
allRoutes := slices.Clone(n.currentRoutes)
|
||||
allRoutes = append(allRoutes, n.extraInitialRoutes()...)
|
||||
if n.fakeIPRoute != nil {
|
||||
allRoutes = append(allRoutes, n.fakeIPRoute)
|
||||
}
|
||||
|
||||
routeStrings := n.routesToStrings(allRoutes)
|
||||
sort.Strings(routeStrings)
|
||||
@@ -78,23 +85,6 @@ func (n *Notifier) notify() {
|
||||
}(n.listener)
|
||||
}
|
||||
|
||||
// extraInitialRoutes returns initialRoutes whose network prefix is absent
|
||||
// from currentRoutes (e.g. the fake IP block added at setup time).
|
||||
func (n *Notifier) extraInitialRoutes() []*route.Route {
|
||||
currentNets := make(map[netip.Prefix]struct{}, len(n.currentRoutes))
|
||||
for _, r := range n.currentRoutes {
|
||||
currentNets[r.Network] = struct{}{}
|
||||
}
|
||||
|
||||
var extra []*route.Route
|
||||
for _, r := range n.initialRoutes {
|
||||
if _, ok := currentNets[r.Network]; !ok {
|
||||
extra = append(extra, r)
|
||||
}
|
||||
}
|
||||
return extra
|
||||
}
|
||||
|
||||
func filterStatic(routes []*route.Route) []*route.Route {
|
||||
out := make([]*route.Route, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
|
||||
@@ -34,6 +34,10 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
|
||||
// iOS doesn't care about initial routes
|
||||
}
|
||||
|
||||
func (n *Notifier) SetFakeIPRoute(*route.Route) {
|
||||
// Not used on iOS
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewRoutes(route.HAMap) {
|
||||
// Not used on iOS
|
||||
}
|
||||
|
||||
@@ -23,6 +23,10 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
|
||||
// Not used on non-mobile platforms
|
||||
}
|
||||
|
||||
func (n *Notifier) SetFakeIPRoute(*route.Route) {
|
||||
// Not used on non-mobile platforms
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||
// Not used on non-mobile platforms
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
@@ -26,11 +25,22 @@ import (
|
||||
"github.com/netbirdio/netbird/util/wsproxy"
|
||||
)
|
||||
|
||||
var ErrClientClosed = errors.New("client is closed")
|
||||
|
||||
// minHealthyDuration is the minimum time a stream must survive before a failure
|
||||
// resets the backoff timer. Streams that fail faster are considered unhealthy and
|
||||
// should not reset backoff, so that MaxElapsedTime can eventually stop retries.
|
||||
const minHealthyDuration = 5 * time.Second
|
||||
|
||||
type GRPCClient struct {
|
||||
realClient proto.FlowServiceClient
|
||||
clientConn *grpc.ClientConn
|
||||
stream proto.FlowService_EventsClient
|
||||
streamMu sync.Mutex
|
||||
target string
|
||||
opts []grpc.DialOption
|
||||
closed bool // prevent creating conn in the middle of the Close
|
||||
receiving bool // prevent concurrent Receive calls
|
||||
mu sync.Mutex // protects clientConn, realClient, stream, closed, and receiving
|
||||
}
|
||||
|
||||
func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCClient, error) {
|
||||
@@ -65,7 +75,8 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
|
||||
grpc.WithDefaultServiceConfig(`{"healthCheckConfig": {"serviceName": ""}}`),
|
||||
)
|
||||
|
||||
conn, err := grpc.NewClient(fmt.Sprintf("%s:%s", parsedURL.Hostname(), parsedURL.Port()), opts...)
|
||||
target := parsedURL.Host
|
||||
conn, err := grpc.NewClient(target, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating new grpc client: %w", err)
|
||||
}
|
||||
@@ -73,30 +84,73 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
|
||||
return &GRPCClient{
|
||||
realClient: proto.NewFlowServiceClient(conn),
|
||||
clientConn: conn,
|
||||
target: target,
|
||||
opts: opts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *GRPCClient) Close() error {
|
||||
c.streamMu.Lock()
|
||||
defer c.streamMu.Unlock()
|
||||
|
||||
c.mu.Lock()
|
||||
c.closed = true
|
||||
c.stream = nil
|
||||
if err := c.clientConn.Close(); err != nil && !errors.Is(err, context.Canceled) {
|
||||
conn := c.clientConn
|
||||
c.clientConn = nil
|
||||
c.mu.Unlock()
|
||||
|
||||
if conn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := conn.Close(); err != nil && !errors.Is(err, context.Canceled) {
|
||||
return fmt.Errorf("close client connection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *GRPCClient) Send(event *proto.FlowEvent) error {
|
||||
c.mu.Lock()
|
||||
stream := c.stream
|
||||
c.mu.Unlock()
|
||||
|
||||
if stream == nil {
|
||||
return errors.New("stream not initialized")
|
||||
}
|
||||
|
||||
if err := stream.Send(event); err != nil {
|
||||
return fmt.Errorf("send flow event: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHandler func(msg *proto.FlowEventAck) error) error {
|
||||
c.mu.Lock()
|
||||
if c.receiving {
|
||||
c.mu.Unlock()
|
||||
return errors.New("concurrent Receive calls are not supported")
|
||||
}
|
||||
c.receiving = true
|
||||
c.mu.Unlock()
|
||||
defer func() {
|
||||
c.mu.Lock()
|
||||
c.receiving = false
|
||||
c.mu.Unlock()
|
||||
}()
|
||||
|
||||
backOff := defaultBackoff(ctx, interval)
|
||||
operation := func() error {
|
||||
if err := c.establishStreamAndReceive(ctx, msgHandler); err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
|
||||
return fmt.Errorf("receive: %w: %w", err, context.Canceled)
|
||||
}
|
||||
stream, err := c.establishStream(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("failed to establish flow stream, retrying: %v", err)
|
||||
return c.handleRetryableError(err, time.Time{}, backOff)
|
||||
}
|
||||
|
||||
streamStart := time.Now()
|
||||
|
||||
if err := c.receive(stream, msgHandler); err != nil {
|
||||
log.Errorf("receive failed: %v", err)
|
||||
return fmt.Errorf("receive: %w", err)
|
||||
return c.handleRetryableError(err, streamStart, backOff)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -108,37 +162,106 @@ func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHan
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *GRPCClient) establishStreamAndReceive(ctx context.Context, msgHandler func(msg *proto.FlowEventAck) error) error {
|
||||
if c.clientConn.GetState() == connectivity.Shutdown {
|
||||
return errors.New("connection to flow receiver has been shut down")
|
||||
// handleRetryableError resets the backoff timer if the stream was healthy long
|
||||
// enough and recreates the underlying ClientConn so that gRPC's internal
|
||||
// subchannel backoff does not accumulate and compete with our own retry timer.
|
||||
// A zero streamStart means the stream was never established.
|
||||
func (c *GRPCClient) handleRetryableError(err error, streamStart time.Time, backOff backoff.BackOff) error {
|
||||
if isContextDone(err) {
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
|
||||
stream, err := c.realClient.Events(ctx, grpc.WaitForReady(true))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create event stream: %w", err)
|
||||
var permErr *backoff.PermanentError
|
||||
if errors.As(err, &permErr) {
|
||||
return err
|
||||
}
|
||||
|
||||
err = stream.Send(&proto.FlowEvent{IsInitiator: true})
|
||||
// Reset the backoff so the next retry starts with a short delay instead of
|
||||
// continuing the already-elapsed timer. Only do this if the stream was healthy
|
||||
// long enough; short-lived connect/drop cycles must not defeat MaxElapsedTime.
|
||||
if !streamStart.IsZero() && time.Since(streamStart) >= minHealthyDuration {
|
||||
backOff.Reset()
|
||||
}
|
||||
|
||||
if recreateErr := c.recreateConnection(); recreateErr != nil {
|
||||
log.Errorf("recreate connection: %v", recreateErr)
|
||||
return recreateErr
|
||||
}
|
||||
|
||||
log.Infof("connection recreated, retrying stream")
|
||||
return fmt.Errorf("retrying after error: %w", err)
|
||||
}
|
||||
|
||||
func (c *GRPCClient) recreateConnection() error {
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return backoff.Permanent(ErrClientClosed)
|
||||
}
|
||||
|
||||
conn, err := grpc.NewClient(c.target, c.opts...)
|
||||
if err != nil {
|
||||
log.Infof("failed to send initiator message to flow receiver but will attempt to continue. Error: %s", err)
|
||||
c.mu.Unlock()
|
||||
return fmt.Errorf("create new connection: %w", err)
|
||||
}
|
||||
|
||||
old := c.clientConn
|
||||
c.clientConn = conn
|
||||
c.realClient = proto.NewFlowServiceClient(conn)
|
||||
c.stream = nil
|
||||
c.mu.Unlock()
|
||||
|
||||
_ = old.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *GRPCClient) establishStream(ctx context.Context) (proto.FlowService_EventsClient, error) {
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return nil, backoff.Permanent(ErrClientClosed)
|
||||
}
|
||||
cl := c.realClient
|
||||
c.mu.Unlock()
|
||||
|
||||
// open stream outside the lock — blocking operation
|
||||
stream, err := cl.Events(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create event stream: %w", err)
|
||||
}
|
||||
streamReady := false
|
||||
defer func() {
|
||||
if !streamReady {
|
||||
_ = stream.CloseSend()
|
||||
}
|
||||
}()
|
||||
|
||||
if err = stream.Send(&proto.FlowEvent{IsInitiator: true}); err != nil {
|
||||
return nil, fmt.Errorf("send initiator: %w", err)
|
||||
}
|
||||
|
||||
if err = checkHeader(stream); err != nil {
|
||||
return fmt.Errorf("check header: %w", err)
|
||||
return nil, fmt.Errorf("check header: %w", err)
|
||||
}
|
||||
|
||||
c.streamMu.Lock()
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return nil, backoff.Permanent(ErrClientClosed)
|
||||
}
|
||||
c.stream = stream
|
||||
c.streamMu.Unlock()
|
||||
c.mu.Unlock()
|
||||
streamReady = true
|
||||
|
||||
return c.receive(stream, msgHandler)
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (c *GRPCClient) receive(stream proto.FlowService_EventsClient, msgHandler func(msg *proto.FlowEventAck) error) error {
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("receive from stream: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if msg.IsInitiator {
|
||||
@@ -169,7 +292,7 @@ func checkHeader(stream proto.FlowService_EventsClient) error {
|
||||
func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff {
|
||||
return backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 800 * time.Millisecond,
|
||||
RandomizationFactor: 1,
|
||||
RandomizationFactor: 0.5,
|
||||
Multiplier: 1.7,
|
||||
MaxInterval: interval / 2,
|
||||
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
|
||||
@@ -178,18 +301,12 @@ func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff
|
||||
}, ctx)
|
||||
}
|
||||
|
||||
func (c *GRPCClient) Send(event *proto.FlowEvent) error {
|
||||
c.streamMu.Lock()
|
||||
stream := c.stream
|
||||
c.streamMu.Unlock()
|
||||
|
||||
if stream == nil {
|
||||
return errors.New("stream not initialized")
|
||||
func isContextDone(err error) bool {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return true
|
||||
}
|
||||
|
||||
if err := stream.Send(event); err != nil {
|
||||
return fmt.Errorf("send flow event: %w", err)
|
||||
if s, ok := status.FromError(err); ok {
|
||||
return s.Code() == codes.Canceled || s.Code() == codes.DeadlineExceeded
|
||||
}
|
||||
|
||||
return nil
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -2,8 +2,11 @@ package client_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -11,6 +14,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
flow "github.com/netbirdio/netbird/flow/client"
|
||||
"github.com/netbirdio/netbird/flow/proto"
|
||||
@@ -18,21 +23,89 @@ import (
|
||||
|
||||
type testServer struct {
|
||||
proto.UnimplementedFlowServiceServer
|
||||
events chan *proto.FlowEvent
|
||||
acks chan *proto.FlowEventAck
|
||||
grpcSrv *grpc.Server
|
||||
addr string
|
||||
events chan *proto.FlowEvent
|
||||
acks chan *proto.FlowEventAck
|
||||
grpcSrv *grpc.Server
|
||||
addr string
|
||||
listener *connTrackListener
|
||||
closeStream chan struct{} // signal server to close the stream
|
||||
handlerDone chan struct{} // signaled each time Events() exits
|
||||
handlerStarted chan struct{} // signaled each time Events() begins
|
||||
}
|
||||
|
||||
// connTrackListener wraps a net.Listener to track accepted connections
|
||||
// so tests can forcefully close them to simulate PROTOCOL_ERROR/RST_STREAM.
|
||||
type connTrackListener struct {
|
||||
net.Listener
|
||||
mu sync.Mutex
|
||||
conns []net.Conn
|
||||
}
|
||||
|
||||
func (l *connTrackListener) Accept() (net.Conn, error) {
|
||||
c, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l.mu.Lock()
|
||||
l.conns = append(l.conns, c)
|
||||
l.mu.Unlock()
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// sendRSTStream writes a raw HTTP/2 RST_STREAM frame with PROTOCOL_ERROR
|
||||
// (error code 0x1) on every tracked connection. This produces the exact error:
|
||||
//
|
||||
// rpc error: code = Internal desc = stream terminated by RST_STREAM with error code: PROTOCOL_ERROR
|
||||
//
|
||||
// HTTP/2 RST_STREAM frame format (9-byte header + 4-byte payload):
|
||||
//
|
||||
// Length (3 bytes): 0x000004
|
||||
// Type (1 byte): 0x03 (RST_STREAM)
|
||||
// Flags (1 byte): 0x00
|
||||
// Stream ID (4 bytes): target stream (must have bit 31 clear)
|
||||
// Error Code (4 bytes): 0x00000001 (PROTOCOL_ERROR)
|
||||
func (l *connTrackListener) connCount() int {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
return len(l.conns)
|
||||
}
|
||||
|
||||
func (l *connTrackListener) sendRSTStream(streamID uint32) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
frame := make([]byte, 13) // 9-byte header + 4-byte payload
|
||||
// Length = 4 (3 bytes, big-endian)
|
||||
frame[0], frame[1], frame[2] = 0, 0, 4
|
||||
// Type = RST_STREAM (0x03)
|
||||
frame[3] = 0x03
|
||||
// Flags = 0
|
||||
frame[4] = 0x00
|
||||
// Stream ID (4 bytes, big-endian, bit 31 reserved = 0)
|
||||
binary.BigEndian.PutUint32(frame[5:9], streamID)
|
||||
// Error Code = PROTOCOL_ERROR (0x1)
|
||||
binary.BigEndian.PutUint32(frame[9:13], 0x1)
|
||||
|
||||
for _, c := range l.conns {
|
||||
_, _ = c.Write(frame)
|
||||
}
|
||||
}
|
||||
|
||||
func newTestServer(t *testing.T) *testServer {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
rawListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
listener := &connTrackListener{Listener: rawListener}
|
||||
|
||||
s := &testServer{
|
||||
events: make(chan *proto.FlowEvent, 100),
|
||||
acks: make(chan *proto.FlowEventAck, 100),
|
||||
grpcSrv: grpc.NewServer(),
|
||||
addr: listener.Addr().String(),
|
||||
events: make(chan *proto.FlowEvent, 100),
|
||||
acks: make(chan *proto.FlowEventAck, 100),
|
||||
grpcSrv: grpc.NewServer(),
|
||||
addr: rawListener.Addr().String(),
|
||||
listener: listener,
|
||||
closeStream: make(chan struct{}, 1),
|
||||
handlerDone: make(chan struct{}, 10),
|
||||
handlerStarted: make(chan struct{}, 10),
|
||||
}
|
||||
|
||||
proto.RegisterFlowServiceServer(s.grpcSrv, s)
|
||||
@@ -51,11 +124,23 @@ func newTestServer(t *testing.T) *testServer {
|
||||
}
|
||||
|
||||
func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
|
||||
defer func() {
|
||||
select {
|
||||
case s.handlerDone <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}()
|
||||
|
||||
err := stream.Send(&proto.FlowEventAck{IsInitiator: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case s.handlerStarted <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(stream.Context())
|
||||
defer cancel()
|
||||
|
||||
@@ -91,6 +176,8 @@ func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
|
||||
if err := stream.Send(ack); err != nil {
|
||||
return err
|
||||
}
|
||||
case <-s.closeStream:
|
||||
return status.Errorf(codes.Internal, "server closing stream")
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
@@ -110,16 +197,13 @@ func TestReceive(t *testing.T) {
|
||||
assert.NoError(t, err, "failed to close flow")
|
||||
})
|
||||
|
||||
receivedAcks := make(map[string]bool)
|
||||
var ackCount atomic.Int32
|
||||
receiveDone := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||
if !msg.IsInitiator && len(msg.EventId) > 0 {
|
||||
id := string(msg.EventId)
|
||||
receivedAcks[id] = true
|
||||
|
||||
if len(receivedAcks) >= 3 {
|
||||
if ackCount.Add(1) >= 3 {
|
||||
close(receiveDone)
|
||||
}
|
||||
}
|
||||
@@ -130,7 +214,11 @@ func TestReceive(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
select {
|
||||
case <-server.handlerStarted:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for stream to be established")
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
eventID := uuid.New().String()
|
||||
@@ -153,7 +241,7 @@ func TestReceive(t *testing.T) {
|
||||
t.Fatal("timeout waiting for acks to be processed")
|
||||
}
|
||||
|
||||
assert.Equal(t, 3, len(receivedAcks))
|
||||
assert.Equal(t, int32(3), ackCount.Load())
|
||||
}
|
||||
|
||||
func TestReceive_ContextCancellation(t *testing.T) {
|
||||
@@ -254,3 +342,195 @@ func TestSend(t *testing.T) {
|
||||
t.Fatal("timeout waiting for ack to be received by flow")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClient_PermanentClose(t *testing.T) {
|
||||
server := newTestServer(t)
|
||||
|
||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
require.ErrorIs(t, err, flow.ErrClientClosed)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Receive did not return after Close — stuck in retry loop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClient_CloseVerify(t *testing.T) {
|
||||
server := newTestServer(t)
|
||||
|
||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
closeDone := make(chan struct{}, 1)
|
||||
go func() {
|
||||
_ = client.Close()
|
||||
closeDone <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
require.Error(t, err)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Receive did not return after Close — stuck in retry loop")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-closeDone:
|
||||
return
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Close did not return — blocked in retry loop")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestClose_WhileReceiving(t *testing.T) {
|
||||
server := newTestServer(t)
|
||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background() // no timeout — intentional
|
||||
receiveDone := make(chan struct{})
|
||||
go func() {
|
||||
_ = client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||
return nil
|
||||
})
|
||||
close(receiveDone)
|
||||
}()
|
||||
|
||||
// Wait for the server-side handler to confirm the stream is established.
|
||||
select {
|
||||
case <-server.handlerStarted:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for stream to be established")
|
||||
}
|
||||
|
||||
closeDone := make(chan struct{})
|
||||
go func() {
|
||||
_ = client.Close()
|
||||
close(closeDone)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-closeDone:
|
||||
// Close returned — good
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Close blocked forever — Receive stuck in retry loop")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-receiveDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Receive did not exit after Close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
|
||||
server := newTestServer(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err, "failed to close flow")
|
||||
})
|
||||
|
||||
// Track acks received before and after server-side stream close
|
||||
var ackCount atomic.Int32
|
||||
receivedFirst := make(chan struct{})
|
||||
receivedAfterReconnect := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||
if msg.IsInitiator || len(msg.EventId) == 0 {
|
||||
return nil
|
||||
}
|
||||
n := ackCount.Add(1)
|
||||
if n == 1 {
|
||||
close(receivedFirst)
|
||||
}
|
||||
if n == 2 {
|
||||
close(receivedAfterReconnect)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
t.Logf("receive error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for stream to be established, then send first ack
|
||||
select {
|
||||
case <-server.handlerStarted:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for stream to be established")
|
||||
}
|
||||
server.acks <- &proto.FlowEventAck{EventId: []byte("before-close")}
|
||||
|
||||
select {
|
||||
case <-receivedFirst:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for first ack")
|
||||
}
|
||||
|
||||
// Snapshot connection count before injecting the fault.
|
||||
connsBefore := server.listener.connCount()
|
||||
|
||||
// Send a raw HTTP/2 RST_STREAM frame with PROTOCOL_ERROR on the TCP connection.
|
||||
// gRPC multiplexes streams on stream IDs 1, 3, 5, ... (odd, client-initiated).
|
||||
// Stream ID 1 is the client's first stream (our Events bidi stream).
|
||||
// This produces the exact error the client sees in production:
|
||||
// "stream terminated by RST_STREAM with error code: PROTOCOL_ERROR"
|
||||
server.listener.sendRSTStream(1)
|
||||
|
||||
// Wait for the old Events() handler to fully exit so it can no longer
|
||||
// drain s.acks and drop our injected ack on a broken stream.
|
||||
select {
|
||||
case <-server.handlerDone:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("old Events() handler did not exit after RST_STREAM")
|
||||
}
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return server.listener.connCount() > connsBefore
|
||||
}, 5*time.Second, 50*time.Millisecond, "client did not open a new TCP connection after RST_STREAM")
|
||||
|
||||
server.acks <- &proto.FlowEventAck{EventId: []byte("after-close")}
|
||||
|
||||
select {
|
||||
case <-receivedAfterReconnect:
|
||||
// Client successfully reconnected and received ack after server-side stream close
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for ack after server-side stream close — client did not reconnect")
|
||||
}
|
||||
|
||||
assert.GreaterOrEqual(t, int(ackCount.Load()), 2, "should have received acks before and after stream close")
|
||||
assert.GreaterOrEqual(t, server.listener.connCount(), 2, "client should have created at least 2 TCP connections (original + reconnect)")
|
||||
}
|
||||
|
||||
@@ -4751,6 +4751,7 @@ components:
|
||||
enum:
|
||||
- email
|
||||
- webhook
|
||||
- slack
|
||||
example: "email"
|
||||
NotificationEventType:
|
||||
type: string
|
||||
@@ -4804,6 +4805,7 @@ components:
|
||||
Channel-specific target configuration. The shape depends on the `type` field:
|
||||
- `email`: requires an `EmailTarget` object
|
||||
- `webhook`: requires a `WebhookTarget` object
|
||||
- `slack`: requires a `WebhookTarget` object
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/EmailTarget'
|
||||
- $ref: '#/components/schemas/WebhookTarget'
|
||||
@@ -4837,6 +4839,7 @@ components:
|
||||
Channel-specific target configuration. The shape depends on the `type` field:
|
||||
- `email`: an `EmailTarget` object
|
||||
- `webhook`: a `WebhookTarget` object
|
||||
- `slack`: a `WebhookTarget` object
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/EmailTarget'
|
||||
- $ref: '#/components/schemas/WebhookTarget'
|
||||
|
||||
@@ -686,6 +686,7 @@ func (e NetworkResourceType) Valid() bool {
|
||||
// Defines values for NotificationChannelType.
|
||||
const (
|
||||
NotificationChannelTypeEmail NotificationChannelType = "email"
|
||||
NotificationChannelTypeSlack NotificationChannelType = "slack"
|
||||
NotificationChannelTypeWebhook NotificationChannelType = "webhook"
|
||||
)
|
||||
|
||||
@@ -694,6 +695,8 @@ func (e NotificationChannelType) Valid() bool {
|
||||
switch e {
|
||||
case NotificationChannelTypeEmail:
|
||||
return true
|
||||
case NotificationChannelTypeSlack:
|
||||
return true
|
||||
case NotificationChannelTypeWebhook:
|
||||
return true
|
||||
default:
|
||||
|
||||
Reference in New Issue
Block a user