Revert "Merge branch 'main' into feature/remote-debug"

This reverts commit 6d6333058c, reversing
changes made to 446aded1f7.
This commit is contained in:
aliamerj
2025-10-06 12:24:48 +03:00
parent 6d6333058c
commit ba7793ae7b
288 changed files with 3117 additions and 8952 deletions

View File

@@ -9,7 +9,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
"github.com/netbirdio/netbird/shared/relay/client/dialer"
"github.com/netbirdio/netbird/shared/relay/client/dialer/quic"
@@ -144,12 +143,10 @@ type Client struct {
listenerMutex sync.Mutex
stateSubscription *PeersStateSubscription
mtu uint16
}
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string, mtu uint16) *Client {
func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
hashedID := messages.HashID(peerID)
relayLog := log.WithFields(log.Fields{"relay": serverURL})
@@ -158,7 +155,6 @@ func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string,
connectionURL: serverURL,
authTokenStore: authTokenStore,
hashedID: hashedID,
mtu: mtu,
bufPool: &sync.Pool{
New: func() any {
buf := make([]byte, bufferSize)
@@ -296,16 +292,7 @@ func (c *Client) Close() error {
}
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
// Force WebSocket for MTUs larger than default to avoid QUIC DATAGRAM frame size issues
var dialers []dialer.DialeFn
if c.mtu > 0 && c.mtu > iface.DefaultMTU {
c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU)
dialers = []dialer.DialeFn{ws.Dialer{}}
} else {
dialers = []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
}
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, quic.Dialer{}, ws.Dialer{})
conn, err := rd.Dial()
if err != nil {
return nil, err

View File

@@ -10,7 +10,6 @@ import (
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/shared/relay/auth/allow"
"github.com/netbirdio/netbird/shared/relay/auth/hmac"
"github.com/netbirdio/netbird/util"
@@ -64,7 +63,7 @@ func TestClient(t *testing.T) {
t.Fatalf("failed to start server: %s", err)
}
t.Log("alice connecting to server")
clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU)
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
@@ -72,7 +71,7 @@ func TestClient(t *testing.T) {
defer clientAlice.Close()
t.Log("placeholder connecting to server")
clientPlaceHolder := NewClient(serverURL, hmacTokenStore, "clientPlaceHolder", iface.DefaultMTU)
clientPlaceHolder := NewClient(serverURL, hmacTokenStore, "clientPlaceHolder")
err = clientPlaceHolder.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
@@ -80,7 +79,7 @@ func TestClient(t *testing.T) {
defer clientPlaceHolder.Close()
t.Log("Bob connecting to server")
clientBob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU)
clientBob := NewClient(serverURL, hmacTokenStore, "bob")
err = clientBob.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
@@ -138,7 +137,7 @@ func TestRegistration(t *testing.T) {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU)
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err != nil {
_ = srv.Shutdown(ctx)
@@ -178,7 +177,7 @@ func TestRegistrationTimeout(t *testing.T) {
_ = fakeTCPListener.Close()
}(fakeTCPListener)
clientAlice := NewClient("127.0.0.1:1234", hmacTokenStore, "alice", iface.DefaultMTU)
clientAlice := NewClient("127.0.0.1:1234", hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err == nil {
t.Errorf("failed to connect to server: %s", err)
@@ -219,7 +218,7 @@ func TestEcho(t *testing.T) {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(serverURL, hmacTokenStore, idAlice, iface.DefaultMTU)
clientAlice := NewClient(serverURL, hmacTokenStore, idAlice)
err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
@@ -231,7 +230,7 @@ func TestEcho(t *testing.T) {
}
}()
clientBob := NewClient(serverURL, hmacTokenStore, idBob, iface.DefaultMTU)
clientBob := NewClient(serverURL, hmacTokenStore, idBob)
err = clientBob.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
@@ -309,7 +308,7 @@ func TestBindToUnavailabePeer(t *testing.T) {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU)
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
@@ -355,13 +354,13 @@ func TestBindReconnect(t *testing.T) {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU)
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
clientBob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU)
clientBob := NewClient(serverURL, hmacTokenStore, "bob")
err = clientBob.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
@@ -383,7 +382,7 @@ func TestBindReconnect(t *testing.T) {
t.Errorf("failed to close client: %s", err)
}
clientAlice = NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU)
clientAlice = NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
@@ -456,13 +455,13 @@ func TestCloseConn(t *testing.T) {
t.Fatalf("failed to start server: %s", err)
}
bob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU)
bob := NewClient(serverURL, hmacTokenStore, "bob")
err = bob.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU)
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
@@ -518,13 +517,13 @@ func TestCloseRelayConn(t *testing.T) {
t.Fatalf("failed to start server: %s", err)
}
bob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU)
bob := NewClient(serverURL, hmacTokenStore, "bob")
err = bob.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU)
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
@@ -572,7 +571,7 @@ func TestCloseByServer(t *testing.T) {
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(serverURL, hmacTokenStore, idAlice, iface.DefaultMTU)
relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
if err = relayClient.Connect(ctx); err != nil {
log.Fatalf("failed to connect to server: %s", err)
}
@@ -628,7 +627,7 @@ func TestCloseByClient(t *testing.T) {
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(serverURL, hmacTokenStore, idAlice, iface.DefaultMTU)
relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
err = relayClient.Connect(ctx)
if err != nil {
log.Fatalf("failed to connect to server: %s", err)
@@ -679,7 +678,7 @@ func TestCloseNotDrainedChannel(t *testing.T) {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(serverURL, hmacTokenStore, idAlice, iface.DefaultMTU)
clientAlice := NewClient(serverURL, hmacTokenStore, idAlice)
err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
@@ -691,7 +690,7 @@ func TestCloseNotDrainedChannel(t *testing.T) {
}
}()
clientBob := NewClient(serverURL, hmacTokenStore, idBob, iface.DefaultMTU)
clientBob := NewClient(serverURL, hmacTokenStore, idBob)
err = clientBob.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)

View File

@@ -12,7 +12,7 @@ import (
log "github.com/sirupsen/logrus"
quictls "github.com/netbirdio/netbird/shared/relay/tls"
nbnet "github.com/netbirdio/netbird/client/net"
nbnet "github.com/netbirdio/netbird/util/net"
)
type Dialer struct {

View File

@@ -16,7 +16,7 @@ import (
"github.com/netbirdio/netbird/shared/relay"
"github.com/netbirdio/netbird/util/embeddedroots"
nbnet "github.com/netbirdio/netbird/client/net"
nbnet "github.com/netbirdio/netbird/util/net"
)
type Dialer struct {

View File

@@ -63,25 +63,20 @@ type Manager struct {
onDisconnectedListeners map[string]*list.List
onReconnectedListenerFn func()
listenerLock sync.Mutex
mtu uint16
}
// NewManager creates a new manager instance.
// The serverURL address can be empty. In this case, the manager will not serve.
func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16) *Manager {
func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager {
tokenStore := &relayAuth.TokenStore{}
m := &Manager{
ctx: ctx,
peerID: peerID,
tokenStore: tokenStore,
mtu: mtu,
serverPicker: &ServerPicker{
TokenStore: tokenStore,
PeerID: peerID,
MTU: mtu,
ConnectionTimeout: defaultConnectionTimeout,
TokenStore: tokenStore,
PeerID: peerID,
},
relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]*list.List),
@@ -258,7 +253,7 @@ func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string
m.relayClients[serverAddress] = rt
m.relayClientsMutex.Unlock()
relayClient := NewClient(serverAddress, m.tokenStore, m.peerID, m.mtu)
relayClient := NewClient(serverAddress, m.tokenStore, m.peerID)
err := relayClient.Connect(m.ctx)
if err != nil {
rt.err = err

View File

@@ -8,15 +8,14 @@ import (
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/shared/relay/auth/allow"
"github.com/netbirdio/netbird/relay/server"
)
func TestEmptyURL(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
mgr := NewManager(ctx, nil, "alice", iface.DefaultMTU)
mgr := NewManager(ctx, nil, "alice")
err := mgr.Serve()
if err == nil {
t.Errorf("expected error, got nil")
@@ -91,12 +90,12 @@ func TestForeignConn(t *testing.T) {
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientAlice := NewManager(mCtx, toURL(lstCfg1), "alice", iface.DefaultMTU)
clientAlice := NewManager(mCtx, toURL(lstCfg1), "alice")
if err := clientAlice.Serve(); err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
clientBob := NewManager(mCtx, toURL(srvCfg2), "bob", iface.DefaultMTU)
clientBob := NewManager(mCtx, toURL(srvCfg2), "bob")
if err := clientBob.Serve(); err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
@@ -198,12 +197,12 @@ func TestForeginConnClose(t *testing.T) {
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
mgrBob := NewManager(mCtx, toURL(srvCfg2), "bob", iface.DefaultMTU)
mgrBob := NewManager(mCtx, toURL(srvCfg2), "bob")
if err := mgrBob.Serve(); err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
mgr := NewManager(mCtx, toURL(srvCfg1), "alice", iface.DefaultMTU)
mgr := NewManager(mCtx, toURL(srvCfg1), "alice")
err = mgr.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
@@ -283,7 +282,7 @@ func TestForeignAutoClose(t *testing.T) {
t.Log("connect to server 1.")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
mgr := NewManager(mCtx, toURL(srvCfg1), idAlice, iface.DefaultMTU)
mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
err = mgr.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
@@ -354,13 +353,13 @@ func TestAutoReconnect(t *testing.T) {
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientBob := NewManager(mCtx, toURL(srvCfg), "bob", iface.DefaultMTU)
clientBob := NewManager(mCtx, toURL(srvCfg), "bob")
err = clientBob.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
clientAlice := NewManager(mCtx, toURL(srvCfg), "alice", iface.DefaultMTU)
clientAlice := NewManager(mCtx, toURL(srvCfg), "alice")
err = clientAlice.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
@@ -429,12 +428,12 @@ func TestNotifierDoubleAdd(t *testing.T) {
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientBob := NewManager(mCtx, toURL(listenerCfg1), "bob", iface.DefaultMTU)
clientBob := NewManager(mCtx, toURL(listenerCfg1), "bob")
if err = clientBob.Serve(); err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
clientAlice := NewManager(mCtx, toURL(listenerCfg1), "alice", iface.DefaultMTU)
clientAlice := NewManager(mCtx, toURL(listenerCfg1), "alice")
if err = clientAlice.Serve(); err != nil {
t.Fatalf("failed to serve manager: %s", err)
}

View File

@@ -13,8 +13,11 @@ import (
)
const (
maxConcurrentServers = 7
defaultConnectionTimeout = 30 * time.Second
maxConcurrentServers = 7
)
var (
connectionTimeout = 30 * time.Second
)
type connResult struct {
@@ -24,15 +27,13 @@ type connResult struct {
}
type ServerPicker struct {
TokenStore *auth.TokenStore
ServerURLs atomic.Value
PeerID string
MTU uint16
ConnectionTimeout time.Duration
TokenStore *auth.TokenStore
ServerURLs atomic.Value
PeerID string
}
func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
ctx, cancel := context.WithTimeout(parentCtx, sp.ConnectionTimeout)
ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout)
defer cancel()
totalServers := len(sp.ServerURLs.Load().([]string))
@@ -69,7 +70,7 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
log.Infof("try to connecting to relay server: %s", url)
relayClient := NewClient(url, sp.TokenStore, sp.PeerID, sp.MTU)
relayClient := NewClient(url, sp.TokenStore, sp.PeerID)
err := relayClient.Connect(ctx)
resultChan <- connResult{
RelayClient: relayClient,

View File

@@ -8,15 +8,15 @@ import (
)
func TestServerPicker_UnavailableServers(t *testing.T) {
timeout := 5 * time.Second
connectionTimeout = 5 * time.Second
sp := ServerPicker{
TokenStore: nil,
PeerID: "test",
ConnectionTimeout: timeout,
TokenStore: nil,
PeerID: "test",
}
sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"})
ctx, cancel := context.WithTimeout(context.Background(), timeout+1)
ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1)
defer cancel()
go func() {

View File

@@ -1,24 +0,0 @@
package healthcheck
import (
"os"
"strconv"
log "github.com/sirupsen/logrus"
)
const (
defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD"
)
func getAttemptThresholdFromEnv() int {
if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" {
threshold, err := strconv.ParseInt(attemptThreshold, 10, 64)
if err != nil {
log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold)
return defaultAttemptThreshold
}
return int(threshold)
}
return defaultAttemptThreshold
}

View File

@@ -1,36 +0,0 @@
package healthcheck
import (
"os"
"testing"
)
//nolint:tenv
func TestGetAttemptThresholdFromEnv(t *testing.T) {
tests := []struct {
name string
envValue string
expected int
}{
{"Default attempt threshold when env is not set", "", defaultAttemptThreshold},
{"Custom attempt threshold when env is set to a valid integer", "3", 3},
{"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.envValue == "" {
os.Unsetenv(defaultAttemptThresholdEnv)
} else {
os.Setenv(defaultAttemptThresholdEnv, tt.envValue)
}
result := getAttemptThresholdFromEnv()
if result != tt.expected {
t.Fatalf("Expected %d, got %d", tt.expected, result)
}
os.Unsetenv(defaultAttemptThresholdEnv)
})
}
}

View File

@@ -7,15 +7,10 @@ import (
log "github.com/sirupsen/logrus"
)
const (
defaultHeartbeatTimeout = defaultHealthCheckInterval + 10*time.Second
var (
heartbeatTimeout = healthCheckInterval + 10*time.Second
)
type ReceiverOptions struct {
HeartbeatTimeout time.Duration
AttemptThreshold int
}
// Receiver is a healthcheck receiver
// It will listen for heartbeat and check if the heartbeat is not received in a certain time
// If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work
@@ -32,23 +27,6 @@ type Receiver struct {
// NewReceiver creates a new healthcheck receiver and start the timer in the background
func NewReceiver(log *log.Entry) *Receiver {
opts := ReceiverOptions{
HeartbeatTimeout: defaultHeartbeatTimeout,
AttemptThreshold: getAttemptThresholdFromEnv(),
}
return NewReceiverWithOpts(log, opts)
}
func NewReceiverWithOpts(log *log.Entry, opts ReceiverOptions) *Receiver {
heartbeatTimeout := opts.HeartbeatTimeout
if heartbeatTimeout <= 0 {
heartbeatTimeout = defaultHeartbeatTimeout
}
attemptThreshold := opts.AttemptThreshold
if attemptThreshold <= 0 {
attemptThreshold = defaultAttemptThreshold
}
ctx, ctxCancel := context.WithCancel(context.Background())
r := &Receiver{
@@ -57,10 +35,10 @@ func NewReceiverWithOpts(log *log.Entry, opts ReceiverOptions) *Receiver {
ctx: ctx,
ctxCancel: ctxCancel,
heartbeat: make(chan struct{}, 1),
attemptThreshold: attemptThreshold,
attemptThreshold: getAttemptThresholdFromEnv(),
}
go r.waitForHealthcheck(heartbeatTimeout)
go r.waitForHealthcheck()
return r
}
@@ -77,7 +55,7 @@ func (r *Receiver) Stop() {
r.ctxCancel()
}
func (r *Receiver) waitForHealthcheck(heartbeatTimeout time.Duration) {
func (r *Receiver) waitForHealthcheck() {
ticker := time.NewTicker(heartbeatTimeout)
defer ticker.Stop()
defer r.ctxCancel()

View File

@@ -2,18 +2,31 @@ package healthcheck
import (
"context"
"fmt"
"os"
"sync"
"testing"
"time"
log "github.com/sirupsen/logrus"
)
func TestNewReceiver(t *testing.T) {
// Mutex to protect global variable access in tests
var testMutex sync.Mutex
opts := ReceiverOptions{
HeartbeatTimeout: 5 * time.Second,
}
r := NewReceiverWithOpts(log.WithContext(context.Background()), opts)
func TestNewReceiver(t *testing.T) {
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 5 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
select {
@@ -25,10 +38,18 @@ func TestNewReceiver(t *testing.T) {
}
func TestNewReceiverNotReceive(t *testing.T) {
opts := ReceiverOptions{
HeartbeatTimeout: 1 * time.Second,
}
r := NewReceiverWithOpts(log.WithContext(context.Background()), opts)
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 1 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
select {
@@ -40,10 +61,18 @@ func TestNewReceiverNotReceive(t *testing.T) {
}
func TestNewReceiverAck(t *testing.T) {
opts := ReceiverOptions{
HeartbeatTimeout: 2 * time.Second,
}
r := NewReceiverWithOpts(log.WithContext(context.Background()), opts)
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 2 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
r.Heartbeat()
@@ -68,19 +97,30 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) {
healthCheckInterval := 1 * time.Second
testMutex.Lock()
originalInterval := healthCheckInterval
originalTimeout := heartbeatTimeout
healthCheckInterval = 1 * time.Second
heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
testMutex.Unlock()
opts := ReceiverOptions{
HeartbeatTimeout: healthCheckInterval + 500*time.Millisecond,
AttemptThreshold: tc.threshold,
}
defer func() {
testMutex.Lock()
healthCheckInterval = originalInterval
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
//nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
defer os.Unsetenv(defaultAttemptThresholdEnv)
receiver := NewReceiverWithOpts(log.WithField("test_name", tc.name), opts)
receiver := NewReceiver(log.WithField("test_name", tc.name))
testTimeout := opts.HeartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval
testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval
if tc.resetCounterOnce {
receiver.Heartbeat()
t.Logf("reset counter once")
}
select {
@@ -94,6 +134,7 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
}
t.Fatalf("should have timed out before %s", testTimeout)
}
})
}
}

View File

@@ -2,76 +2,52 @@ package healthcheck
import (
"context"
"os"
"strconv"
"time"
log "github.com/sirupsen/logrus"
)
const (
defaultAttemptThreshold = 1
defaultHealthCheckInterval = 25 * time.Second
defaultHealthCheckTimeout = 20 * time.Second
defaultAttemptThreshold = 1
defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD"
)
type SenderOptions struct {
HealthCheckInterval time.Duration
HealthCheckTimeout time.Duration
AttemptThreshold int
}
var (
healthCheckInterval = 25 * time.Second
healthCheckTimeout = 20 * time.Second
)
// Sender is a healthcheck sender
// It will send healthcheck signal to the receiver
// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work
// It will also stop if the context is canceled
type Sender struct {
log *log.Entry
// HealthCheck is a channel to send health check signal to the peer
HealthCheck chan struct{}
// Timeout is a channel to the health check signal is not received in a certain time
Timeout chan struct{}
log *log.Entry
healthCheckInterval time.Duration
timeout time.Duration
ack chan struct{}
alive bool
attemptThreshold int
}
func NewSenderWithOpts(log *log.Entry, opts SenderOptions) *Sender {
if opts.HealthCheckInterval <= 0 {
opts.HealthCheckInterval = defaultHealthCheckInterval
}
if opts.HealthCheckTimeout <= 0 {
opts.HealthCheckTimeout = defaultHealthCheckTimeout
}
if opts.AttemptThreshold <= 0 {
opts.AttemptThreshold = defaultAttemptThreshold
}
// NewSender creates a new healthcheck sender
func NewSender(log *log.Entry) *Sender {
hc := &Sender{
HealthCheck: make(chan struct{}, 1),
Timeout: make(chan struct{}, 1),
log: log,
healthCheckInterval: opts.HealthCheckInterval,
timeout: opts.HealthCheckInterval + opts.HealthCheckTimeout,
ack: make(chan struct{}, 1),
attemptThreshold: opts.AttemptThreshold,
log: log,
HealthCheck: make(chan struct{}, 1),
Timeout: make(chan struct{}, 1),
ack: make(chan struct{}, 1),
attemptThreshold: getAttemptThresholdFromEnv(),
}
return hc
}
// NewSender creates a new healthcheck sender
func NewSender(log *log.Entry) *Sender {
opts := SenderOptions{
HealthCheckInterval: defaultHealthCheckInterval,
HealthCheckTimeout: defaultHealthCheckTimeout,
AttemptThreshold: getAttemptThresholdFromEnv(),
}
return NewSenderWithOpts(log, opts)
}
// OnHCResponse sends an acknowledgment signal to the sender
func (hc *Sender) OnHCResponse() {
select {
@@ -81,10 +57,10 @@ func (hc *Sender) OnHCResponse() {
}
func (hc *Sender) StartHealthCheck(ctx context.Context) {
ticker := time.NewTicker(hc.healthCheckInterval)
ticker := time.NewTicker(healthCheckInterval)
defer ticker.Stop()
timeoutTicker := time.NewTicker(hc.timeout)
timeoutTicker := time.NewTicker(hc.getTimeoutTime())
defer timeoutTicker.Stop()
defer close(hc.HealthCheck)
@@ -116,3 +92,19 @@ func (hc *Sender) StartHealthCheck(ctx context.Context) {
}
}
}
func (hc *Sender) getTimeoutTime() time.Duration {
return healthCheckInterval + healthCheckTimeout
}
func getAttemptThresholdFromEnv() int {
if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" {
threshold, err := strconv.ParseInt(attemptThreshold, 10, 64)
if err != nil {
log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold)
return defaultAttemptThreshold
}
return int(threshold)
}
return defaultAttemptThreshold
}

View File

@@ -2,23 +2,26 @@ package healthcheck
import (
"context"
"fmt"
"os"
"testing"
"time"
log "github.com/sirupsen/logrus"
)
var (
testOpts = SenderOptions{
HealthCheckInterval: 2 * time.Second,
HealthCheckTimeout: 100 * time.Millisecond,
}
)
func TestMain(m *testing.M) {
// override the health check interval to speed up the test
healthCheckInterval = 2 * time.Second
healthCheckTimeout = 100 * time.Millisecond
code := m.Run()
os.Exit(code)
}
func TestNewHealthPeriod(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
iterations := 0
@@ -29,7 +32,7 @@ func TestNewHealthPeriod(t *testing.T) {
hc.OnHCResponse()
case <-hc.Timeout:
t.Fatalf("health check is timed out")
case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond):
case <-time.After(healthCheckInterval + 100*time.Millisecond):
t.Fatalf("health check not received")
}
}
@@ -38,19 +41,19 @@ func TestNewHealthPeriod(t *testing.T) {
func TestNewHealthFailed(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
select {
case <-hc.Timeout:
case <-time.After(testOpts.HealthCheckInterval + testOpts.HealthCheckTimeout + 100*time.Millisecond):
case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond):
t.Fatalf("health check is not timed out")
}
}
func TestNewHealthcheckStop(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
time.Sleep(100 * time.Millisecond)
@@ -75,7 +78,7 @@ func TestNewHealthcheckStop(t *testing.T) {
func TestTimeoutReset(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
iterations := 0
@@ -86,7 +89,7 @@ func TestTimeoutReset(t *testing.T) {
hc.OnHCResponse()
case <-hc.Timeout:
t.Fatalf("health check is timed out")
case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond):
case <-time.After(healthCheckInterval + 100*time.Millisecond):
t.Fatalf("health check not received")
}
}
@@ -115,16 +118,19 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) {
opts := SenderOptions{
HealthCheckInterval: 1 * time.Second,
HealthCheckTimeout: 500 * time.Millisecond,
AttemptThreshold: tc.threshold,
}
originalInterval := healthCheckInterval
originalTimeout := healthCheckTimeout
healthCheckInterval = 1 * time.Second
healthCheckTimeout = 500 * time.Millisecond
//nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
defer os.Unsetenv(defaultAttemptThresholdEnv)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sender := NewSenderWithOpts(log.WithField("test_name", tc.name), opts)
sender := NewSender(log.WithField("test_name", tc.name))
senderExit := make(chan struct{})
go func() {
sender.StartHealthCheck(ctx)
@@ -149,7 +155,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
}
}()
testTimeout := (opts.HealthCheckInterval+opts.HealthCheckTimeout)*time.Duration(tc.threshold) + opts.HealthCheckInterval
testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval
select {
case <-sender.Timeout:
@@ -169,7 +175,39 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
case <-time.After(2 * time.Second):
t.Fatalf("sender did not exit in time")
}
healthCheckInterval = originalInterval
healthCheckTimeout = originalTimeout
})
}
}
//nolint:tenv
func TestGetAttemptThresholdFromEnv(t *testing.T) {
tests := []struct {
name string
envValue string
expected int
}{
{"Default attempt threshold when env is not set", "", defaultAttemptThreshold},
{"Custom attempt threshold when env is set to a valid integer", "3", 3},
{"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.envValue == "" {
os.Unsetenv(defaultAttemptThresholdEnv)
} else {
os.Setenv(defaultAttemptThresholdEnv, tt.envValue)
}
result := getAttemptThresholdFromEnv()
if result != tt.expected {
t.Fatalf("Expected %d, got %d", tt.expected, result)
}
os.Unsetenv(defaultAttemptThresholdEnv)
})
}
}