Merge branch 'main' into feature/remote-debug
@@ -9,34 +9,30 @@ import (
	"time"

	"github.com/golang/mock/gomock"
	"github.com/stretchr/testify/require"

	"github.com/netbirdio/netbird/client/system"
	"github.com/netbirdio/netbird/management/internals/server/config"
	"github.com/netbirdio/netbird/management/server/activity"
	"github.com/netbirdio/netbird/management/server/groups"
	"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
	"github.com/netbirdio/netbird/management/server/permissions"
	"github.com/netbirdio/netbird/management/server/settings"
	"github.com/netbirdio/netbird/management/server/store"
	"github.com/netbirdio/netbird/management/server/telemetry"
	"github.com/netbirdio/netbird/management/server/types"

	log "github.com/sirupsen/logrus"
	"github.com/stretchr/testify/assert"

	"github.com/netbirdio/management-integrations/integrations"

	"github.com/netbirdio/netbird/encryption"
	mgmt "github.com/netbirdio/netbird/management/server"
	"github.com/netbirdio/netbird/management/server/mock_server"
	mgmtProto "github.com/netbirdio/netbird/shared/management/proto"

	"github.com/stretchr/testify/require"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	"github.com/netbirdio/management-integrations/integrations"
	"github.com/netbirdio/netbird/client/system"
	"github.com/netbirdio/netbird/encryption"
	"github.com/netbirdio/netbird/management/internals/server/config"
	mgmt "github.com/netbirdio/netbird/management/server"
	"github.com/netbirdio/netbird/management/server/activity"
	"github.com/netbirdio/netbird/management/server/groups"
	"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
	"github.com/netbirdio/netbird/management/server/mock_server"
	"github.com/netbirdio/netbird/management/server/peers"
	"github.com/netbirdio/netbird/management/server/permissions"
	"github.com/netbirdio/netbird/management/server/settings"
	"github.com/netbirdio/netbird/management/server/store"
	"github.com/netbirdio/netbird/management/server/telemetry"
	"github.com/netbirdio/netbird/management/server/types"
	mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
	"github.com/netbirdio/netbird/util"
)
@@ -73,13 +69,31 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
	peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
	jobManager := mgmt.NewJobManager(nil, store)
	eventStore := &activity.InMemoryEventStore{}
	ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)

	ctrl := gomock.NewController(t)
	t.Cleanup(ctrl.Finish)

	permissionsManagerMock := permissions.NewMockManager(ctrl)
	permissionsManagerMock.
		EXPECT().
		ValidateUserPermissions(
			gomock.Any(),
			gomock.Any(),
			gomock.Any(),
			gomock.Any(),
			gomock.Any(),
		).
		Return(true, nil).
		AnyTimes()

	peersManger := peers.NewManager(store, permissionsManagerMock)
	settingsManagerMock := settings.NewMockManager(ctrl)

	ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManger, settingsManagerMock, eventStore)

	metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
	require.NoError(t, err)

	ctrl := gomock.NewController(t)
	t.Cleanup(ctrl.Finish)
	settingsMockManager := settings.NewMockManager(ctrl)
	settingsMockManager.
		EXPECT().
@@ -110,6 +124,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
		AnyTimes()

	accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, jobManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)

	if err != nil {
		t.Fatal(err)
	}
@@ -18,11 +18,11 @@ import (
	"google.golang.org/grpc"
	"google.golang.org/grpc/connectivity"

	nbgrpc "github.com/netbirdio/netbird/client/grpc"
	"github.com/netbirdio/netbird/client/system"
	"github.com/netbirdio/netbird/encryption"
	"github.com/netbirdio/netbird/shared/management/domain"
	"github.com/netbirdio/netbird/shared/management/proto"
	nbgrpc "github.com/netbirdio/netbird/util/grpc"
)

const ConnectTimeout = 10 * time.Second
@@ -53,7 +53,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE

	operation := func() error {
		var err error
		conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
		conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled)
		if err != nil {
			log.Printf("createConnection error: %v", err)
			return err
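Note: the management client (and the signal client, in the matching hunk at the end of this diff) now passes the caller's context into CreateConnection, so dial attempts can be canceled or deadline-bounded by the caller instead of relying only on an internal timeout. A minimal sketch of the pattern, assuming CreateConnection wraps grpc.DialContext; the real helper in client/grpc may differ:

package nbgrpc

import (
	"context"
	"crypto/tls"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/insecure"
)

// CreateConnection dials addr; the caller's ctx bounds the whole attempt,
// so a canceled login or shutdown aborts the dial immediately.
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
	creds := insecure.NewCredentials()
	if tlsEnabled {
		creds = credentials.NewTLS(&tls.Config{})
	}
	dialCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()
	return grpc.DialContext(dialCtx, addr, grpc.WithTransportCredentials(creds), grpc.WithBlock())
}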
@@ -8,8 +8,8 @@ import (
	"net/http/httptest"
	"testing"

	"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
	"github.com/netbirdio/netbird/shared/management/client/rest"
	"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
)

func withMockClient(callback func(*rest.Client, *http.ServeMux)) {
@@ -26,7 +26,7 @@ func ptr[T any, PT *T](x T) PT {

func withBlackBoxServer(t *testing.T, callback func(*rest.Client)) {
	t.Helper()
	handler, _, _ := testing_tools.BuildApiBlackBoxWithDBState(t, "../../../../management/server/testdata/store.sql", nil, false)
	handler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../../../../management/server/testdata/store.sql", nil, false)
	server := httptest.NewServer(handler)
	defer server.Close()
	c := rest.New(server.URL, "nbp_apTmlmUXHSC4PKmHwtIZNaGr8eqcVI2gMURp")
@@ -278,6 +278,10 @@ components:
        description: (Cloud only) Enables or disables peer approval globally. If enabled, all peers added will be in pending state until approved by an admin.
        type: boolean
        example: true
      user_approval_required:
        description: Enables manual approval for new users joining via domain matching. When enabled, users are blocked with pending approval status until explicitly approved by an admin.
        type: boolean
        example: false
      network_traffic_logs_enabled:
        description: Enables or disables network traffic logging. If enabled, all network traffic events from peers will be stored.
        type: boolean
@@ -294,6 +298,7 @@ components:
        example: true
      required:
        - peer_approval_enabled
        - user_approval_required
        - network_traffic_logs_enabled
        - network_traffic_logs_groups
        - network_traffic_packet_counter_enabled
@@ -355,6 +360,10 @@ components:
        description: Is true if this user is blocked. Blocked users can't use the system
        type: boolean
        example: false
      pending_approval:
        description: Is true if this user requires approval before being activated. Only applicable for users joining via domain matching when user_approval_required is enabled.
        type: boolean
        example: false
      issued:
        description: How user was issued by API or Integration
        type: string
@@ -369,6 +378,7 @@ components:
        - auto_groups
        - status
        - is_blocked
        - pending_approval
    UserPermissions:
      type: object
      properties:
@@ -1462,6 +1472,10 @@ components:
        items:
          type: string
        example: "chacbco6lnnbn6cg5s91"
      skip_auto_apply:
        description: Indicate if this exit node route (0.0.0.0/0) should skip auto-application for client routing
        type: boolean
        example: false
      required:
        - id
        - description
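Note: a hedged sketch of setting the new skip_auto_apply flag when creating an exit-node route over the HTTP API; the base URL and token are placeholders, and only the fields relevant to the flag are shown:

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Exit-node route whose clients should not auto-apply 0.0.0.0/0;
	// other required route fields are omitted for brevity.
	body, _ := json.Marshal(map[string]any{
		"network":         "0.0.0.0/0",
		"skip_auto_apply": true,
	})
	req, _ := http.NewRequest(http.MethodPost, "https://mgmt.example.com/api/routes", bytes.NewReader(body))
	req.Header.Set("Authorization", "Token nbp_exampletoken")
	req.Header.Set("Content-Type", "application/json")
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.Status)
}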
@@ -2764,6 +2778,63 @@ paths:
          "$ref": "#/components/responses/forbidden"
        '500':
          "$ref": "#/components/responses/internal_error"
  /api/users/{userId}/approve:
    post:
      summary: Approve user
      description: Approve a user that is pending approval
      tags: [ Users ]
      security:
        - BearerAuth: [ ]
        - TokenAuth: [ ]
      parameters:
        - in: path
          name: userId
          required: true
          schema:
            type: string
          description: The unique identifier of a user
      responses:
        '200':
          description: Returns the approved user
          content:
            application/json:
              schema:
                "$ref": "#/components/schemas/User"
        '400':
          "$ref": "#/components/responses/bad_request"
        '401':
          "$ref": "#/components/responses/requires_authentication"
        '403':
          "$ref": "#/components/responses/forbidden"
        '500':
          "$ref": "#/components/responses/internal_error"
  /api/users/{userId}/reject:
    delete:
      summary: Reject user
      description: Reject a user that is pending approval by removing them from the account
      tags: [ Users ]
      security:
        - BearerAuth: [ ]
        - TokenAuth: [ ]
      parameters:
        - in: path
          name: userId
          required: true
          schema:
            type: string
          description: The unique identifier of a user
      responses:
        '200':
          description: User rejected successfully
          content: {}
        '400':
          "$ref": "#/components/responses/bad_request"
        '401':
          "$ref": "#/components/responses/requires_authentication"
        '403':
          "$ref": "#/components/responses/forbidden"
        '500':
          "$ref": "#/components/responses/internal_error"
  /api/users/current:
    get:
      summary: Retrieve current user
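Note: a hedged usage sketch for the new approval endpoints, standard library only; the base URL and token are placeholders. Rejecting works the same way with http.MethodDelete against /api/users/{userId}/reject:

package main

import (
	"fmt"
	"net/http"
)

// approveUser calls POST /api/users/{userId}/approve as specified above.
func approveUser(baseURL, token, userID string) error {
	req, err := http.NewRequest(http.MethodPost, baseURL+"/api/users/"+userID+"/approve", nil)
	if err != nil {
		return err
	}
	req.Header.Set("Authorization", "Token "+token)
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("approve failed: %s", resp.Status)
	}
	return nil
}

func main() {
	if err := approveUser("https://mgmt.example.com", "nbp_exampletoken", "user-id"); err != nil {
		fmt.Println(err)
	}
}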
@@ -284,6 +284,9 @@ type AccountExtraSettings struct {

	// PeerApprovalEnabled (Cloud only) Enables or disables peer approval globally. If enabled, all peers added will be in pending state until approved by an admin.
	PeerApprovalEnabled bool `json:"peer_approval_enabled"`

	// UserApprovalRequired Enables manual approval for new users joining via domain matching. When enabled, users are blocked with pending approval status until explicitly approved by an admin.
	UserApprovalRequired bool `json:"user_approval_required"`
}

// AccountOnboarding defines model for AccountOnboarding.
@@ -1619,6 +1622,9 @@ type Route struct {

	// PeerGroups Peers Group Identifier associated with route. This property can not be set together with `peer`
	PeerGroups *[]string `json:"peer_groups,omitempty"`

	// SkipAutoApply Indicate if this exit node route (0.0.0.0/0) should skip auto-application for client routing
	SkipAutoApply *bool `json:"skip_auto_apply,omitempty"`
}

// RouteRequest defines model for RouteRequest.
@@ -1658,6 +1664,9 @@ type RouteRequest struct {

	// PeerGroups Peers Group Identifier associated with route. This property can not be set together with `peer`
	PeerGroups *[]string `json:"peer_groups,omitempty"`

	// SkipAutoApply Indicate if this exit node route (0.0.0.0/0) should skip auto-application for client routing
	SkipAutoApply *bool `json:"skip_auto_apply,omitempty"`
}

// RulePortRange Policy rule affected ports range
@@ -1846,8 +1855,11 @@ type User struct {
	LastLogin *time.Time `json:"last_login,omitempty"`

	// Name User's name from idp provider
	Name        string           `json:"name"`
	Permissions *UserPermissions `json:"permissions,omitempty"`
	Name string `json:"name"`

	// PendingApproval Is true if this user requires approval before being activated. Only applicable for users joining via domain matching when user_approval_required is enabled.
	PendingApproval bool `json:"pending_approval"`
	Permissions *UserPermissions `json:"permissions,omitempty"`

	// Role User's NetBird account role
	Role string `json:"role"`
@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// 	protoc-gen-go v1.26.0
// 	protoc        v3.21.12
// 	protoc        v4.24.3
// source: management.proto

package proto
@@ -1982,6 +1982,7 @@ type PeerConfig struct {
	Fqdn                            string `protobuf:"bytes,4,opt,name=fqdn,proto3" json:"fqdn,omitempty"`
	RoutingPeerDnsResolutionEnabled bool   `protobuf:"varint,5,opt,name=RoutingPeerDnsResolutionEnabled,proto3" json:"RoutingPeerDnsResolutionEnabled,omitempty"`
	LazyConnectionEnabled           bool   `protobuf:"varint,6,opt,name=LazyConnectionEnabled,proto3" json:"LazyConnectionEnabled,omitempty"`
	Mtu                             int32  `protobuf:"varint,7,opt,name=mtu,proto3" json:"mtu,omitempty"`
}

func (x *PeerConfig) Reset() {
@@ -2058,6 +2059,13 @@ func (x *PeerConfig) GetLazyConnectionEnabled() bool {
	return false
}

func (x *PeerConfig) GetMtu() int32 {
	if x != nil {
		return x.Mtu
	}
	return 0
}

// NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections
type NetworkMap struct {
	state protoimpl.MessageState
@@ -2693,15 +2701,16 @@ type Route struct {
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	ID            string   `protobuf:"bytes,1,opt,name=ID,proto3" json:"ID,omitempty"`
	Network       string   `protobuf:"bytes,2,opt,name=Network,proto3" json:"Network,omitempty"`
	NetworkType   int64    `protobuf:"varint,3,opt,name=NetworkType,proto3" json:"NetworkType,omitempty"`
	Peer          string   `protobuf:"bytes,4,opt,name=Peer,proto3" json:"Peer,omitempty"`
	Metric        int64    `protobuf:"varint,5,opt,name=Metric,proto3" json:"Metric,omitempty"`
	Masquerade    bool     `protobuf:"varint,6,opt,name=Masquerade,proto3" json:"Masquerade,omitempty"`
	NetID         string   `protobuf:"bytes,7,opt,name=NetID,proto3" json:"NetID,omitempty"`
	Domains       []string `protobuf:"bytes,8,rep,name=Domains,proto3" json:"Domains,omitempty"`
	KeepRoute     bool     `protobuf:"varint,9,opt,name=keepRoute,proto3" json:"keepRoute,omitempty"`
	SkipAutoApply bool     `protobuf:"varint,10,opt,name=skipAutoApply,proto3" json:"skipAutoApply,omitempty"`
}

func (x *Route) Reset() {
@@ -2799,6 +2808,13 @@ func (x *Route) GetKeepRoute() bool {
	return false
}

func (x *Route) GetSkipAutoApply() bool {
	if x != nil {
		return x.SkipAutoApply
	}
	return false
}

// DNSConfig represents a dns.Update
type DNSConfig struct {
	state protoimpl.MessageState
@@ -303,6 +303,8 @@ message PeerConfig {
  bool RoutingPeerDnsResolutionEnabled = 5;

  bool LazyConnectionEnabled = 6;

  int32 mtu = 7;
}

// NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections
@@ -439,6 +441,7 @@ message Route {
  string NetID = 7;
  repeated string Domains = 8;
  bool keepRoute = 9;
  bool skipAutoApply = 10;
}

// DNSConfig represents a dns.Update
@@ -42,7 +42,10 @@ const (
// Type is a type of the Error
type Type int32

var ErrExtraSettingsNotFound = fmt.Errorf("extra settings not found")
var (
	ErrExtraSettingsNotFound = errors.New("extra settings not found")
	ErrPeerAlreadyLoggedIn   = errors.New("peer with the same public key is already logged in")
)

// Error is an internal error
type Error struct {
@@ -110,6 +113,11 @@ func NewUserBlockedError() error {
	return Errorf(PermissionDenied, "user is blocked")
}

// NewUserPendingApprovalError creates a new Error with PermissionDenied type for a blocked user pending approval
func NewUserPendingApprovalError() error {
	return Errorf(PermissionDenied, "user is pending approval")
}

// NewPeerNotRegisteredError creates a new Error with Unauthenticated type unregistered peer
func NewPeerNotRegisteredError() error {
	return Errorf(Unauthenticated, "peer is not registered")
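Note: turning these into errors.New sentinels lets callers branch with errors.Is even when the error has been wrapped along the way. A self-contained sketch (loginPeer is a hypothetical stand-in for the real login path):

package main

import (
	"errors"
	"fmt"
)

// Sentinels mirroring the ones added above (redeclared here so the
// sketch compiles standalone).
var (
	ErrExtraSettingsNotFound = errors.New("extra settings not found")
	ErrPeerAlreadyLoggedIn   = errors.New("peer with the same public key is already logged in")
)

// loginPeer is a hypothetical stand-in that wraps a sentinel with %w.
func loginPeer(key string) error {
	return fmt.Errorf("login %s: %w", key, ErrPeerAlreadyLoggedIn)
}

func main() {
	err := loginPeer("wg-pub-key")
	switch {
	case errors.Is(err, ErrPeerAlreadyLoggedIn):
		fmt.Println("soft failure: session already exists")
	case errors.Is(err, ErrExtraSettingsNotFound):
		fmt.Println("fall back to default settings")
	default:
		fmt.Println("unexpected:", err)
	}
}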
@@ -9,6 +9,7 @@ import (
	log "github.com/sirupsen/logrus"

	"github.com/netbirdio/netbird/client/iface"
	auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
	"github.com/netbirdio/netbird/shared/relay/client/dialer"
	"github.com/netbirdio/netbird/shared/relay/client/dialer/quic"
@@ -143,10 +144,12 @@ type Client struct {
	listenerMutex sync.Mutex

	stateSubscription *PeersStateSubscription

	mtu uint16
}

// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string, mtu uint16) *Client {
	hashedID := messages.HashID(peerID)
	relayLog := log.WithFields(log.Fields{"relay": serverURL})
@@ -155,6 +158,7 @@ func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string)
		connectionURL:  serverURL,
		authTokenStore: authTokenStore,
		hashedID:       hashedID,
		mtu:            mtu,
		bufPool: &sync.Pool{
			New: func() any {
				buf := make([]byte, bufferSize)
@@ -292,7 +296,16 @@ func (c *Client) Close() error {
}

func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
	rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, quic.Dialer{}, ws.Dialer{})
	// Force WebSocket for MTUs larger than default to avoid QUIC DATAGRAM frame size issues
	var dialers []dialer.DialeFn
	if c.mtu > 0 && c.mtu > iface.DefaultMTU {
		c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU)
		dialers = []dialer.DialeFn{ws.Dialer{}}
	} else {
		dialers = []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
	}

	rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
	conn, err := rd.Dial()
	if err != nil {
		return nil, err
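Note: QUIC carries relayed payloads in DATAGRAM frames whose size is capped by the QUIC path MTU, so a WireGuard MTU above the client default could produce frames the transport cannot deliver; WebSocket over a TCP stream has no per-frame cap. A hedged in-package construction sketch (the URL, peer ID, and MTU value are placeholders; NewClient and Connect are the functions changed above):

// newJumboRelayClient builds a relay client with an oversized MTU; with
// mtu > iface.DefaultMTU the connect path above degrades to WebSocket only.
func newJumboRelayClient(ctx context.Context, serverURL, peerID string, tokenStore *auth.TokenStore) (*Client, error) {
	c := NewClient(serverURL, tokenStore, peerID, 9000) // 9000: illustrative jumbo-frame MTU
	if err := c.Connect(ctx); err != nil {
		return nil, err
	}
	return c, nil
}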
@@ -10,6 +10,7 @@ import (
	log "github.com/sirupsen/logrus"
	"go.opentelemetry.io/otel"

	"github.com/netbirdio/netbird/client/iface"
	"github.com/netbirdio/netbird/shared/relay/auth/allow"
	"github.com/netbirdio/netbird/shared/relay/auth/hmac"
	"github.com/netbirdio/netbird/util"
@@ -63,7 +64,7 @@ func TestClient(t *testing.T) {
		t.Fatalf("failed to start server: %s", err)
	}
	t.Log("alice connecting to server")
	clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
	clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU)
	err = clientAlice.Connect(ctx)
	if err != nil {
		t.Fatalf("failed to connect to server: %s", err)
@@ -71,7 +72,7 @@ func TestClient(t *testing.T) {
	defer clientAlice.Close()

	t.Log("placeholder connecting to server")
	clientPlaceHolder := NewClient(serverURL, hmacTokenStore, "clientPlaceHolder")
	clientPlaceHolder := NewClient(serverURL, hmacTokenStore, "clientPlaceHolder", iface.DefaultMTU)
	err = clientPlaceHolder.Connect(ctx)
	if err != nil {
		t.Fatalf("failed to connect to server: %s", err)
@@ -79,7 +80,7 @@ func TestClient(t *testing.T) {
	defer clientPlaceHolder.Close()

	t.Log("Bob connecting to server")
	clientBob := NewClient(serverURL, hmacTokenStore, "bob")
	clientBob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU)
	err = clientBob.Connect(ctx)
	if err != nil {
		t.Fatalf("failed to connect to server: %s", err)
@@ -137,7 +138,7 @@ func TestRegistration(t *testing.T) {
		t.Fatalf("failed to start server: %s", err)
	}

	clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
	clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU)
	err = clientAlice.Connect(ctx)
	if err != nil {
		_ = srv.Shutdown(ctx)
@@ -177,7 +178,7 @@ func TestRegistrationTimeout(t *testing.T) {
		_ = fakeTCPListener.Close()
	}(fakeTCPListener)

	clientAlice := NewClient("127.0.0.1:1234", hmacTokenStore, "alice")
	clientAlice := NewClient("127.0.0.1:1234", hmacTokenStore, "alice", iface.DefaultMTU)
	err = clientAlice.Connect(ctx)
	if err == nil {
		t.Errorf("failed to connect to server: %s", err)
@@ -218,7 +219,7 @@ func TestEcho(t *testing.T) {
		t.Fatalf("failed to start server: %s", err)
	}

	clientAlice := NewClient(serverURL, hmacTokenStore, idAlice)
	clientAlice := NewClient(serverURL, hmacTokenStore, idAlice, iface.DefaultMTU)
	err = clientAlice.Connect(ctx)
	if err != nil {
		t.Fatalf("failed to connect to server: %s", err)
@@ -230,7 +231,7 @@ func TestEcho(t *testing.T) {
		}
	}()

	clientBob := NewClient(serverURL, hmacTokenStore, idBob)
	clientBob := NewClient(serverURL, hmacTokenStore, idBob, iface.DefaultMTU)
	err = clientBob.Connect(ctx)
	if err != nil {
		t.Fatalf("failed to connect to server: %s", err)
@@ -308,7 +309,7 @@ func TestBindToUnavailabePeer(t *testing.T) {
		t.Fatalf("failed to start server: %s", err)
	}

	clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
	clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU)
	err = clientAlice.Connect(ctx)
	if err != nil {
		t.Errorf("failed to connect to server: %s", err)
@@ -354,13 +355,13 @@ func TestBindReconnect(t *testing.T) {
		t.Fatalf("failed to start server: %s", err)
	}

	clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
	clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU)
	err = clientAlice.Connect(ctx)
	if err != nil {
		t.Fatalf("failed to connect to server: %s", err)
	}

	clientBob := NewClient(serverURL, hmacTokenStore, "bob")
	clientBob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU)
	err = clientBob.Connect(ctx)
	if err != nil {
		t.Errorf("failed to connect to server: %s", err)
@@ -382,7 +383,7 @@ func TestBindReconnect(t *testing.T) {
		t.Errorf("failed to close client: %s", err)
	}

	clientAlice = NewClient(serverURL, hmacTokenStore, "alice")
	clientAlice = NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU)
	err = clientAlice.Connect(ctx)
	if err != nil {
		t.Errorf("failed to connect to server: %s", err)
@@ -455,13 +456,13 @@ func TestCloseConn(t *testing.T) {
		t.Fatalf("failed to start server: %s", err)
	}

	bob := NewClient(serverURL, hmacTokenStore, "bob")
	bob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU)
	err = bob.Connect(ctx)
	if err != nil {
		t.Errorf("failed to connect to server: %s", err)
	}

	clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
	clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU)
	err = clientAlice.Connect(ctx)
	if err != nil {
		t.Errorf("failed to connect to server: %s", err)
@@ -517,13 +518,13 @@ func TestCloseRelayConn(t *testing.T) {
		t.Fatalf("failed to start server: %s", err)
	}

	bob := NewClient(serverURL, hmacTokenStore, "bob")
	bob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU)
	err = bob.Connect(ctx)
	if err != nil {
		t.Fatalf("failed to connect to server: %s", err)
	}

	clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
	clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU)
	err = clientAlice.Connect(ctx)
	if err != nil {
		t.Fatalf("failed to connect to server: %s", err)
@@ -571,7 +572,7 @@ func TestCloseByServer(t *testing.T) {

	idAlice := "alice"
	log.Debugf("connect by alice")
	relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
	relayClient := NewClient(serverURL, hmacTokenStore, idAlice, iface.DefaultMTU)
	if err = relayClient.Connect(ctx); err != nil {
		log.Fatalf("failed to connect to server: %s", err)
	}
@@ -627,7 +628,7 @@ func TestCloseByClient(t *testing.T) {

	idAlice := "alice"
	log.Debugf("connect by alice")
	relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
	relayClient := NewClient(serverURL, hmacTokenStore, idAlice, iface.DefaultMTU)
	err = relayClient.Connect(ctx)
	if err != nil {
		log.Fatalf("failed to connect to server: %s", err)
@@ -678,7 +679,7 @@ func TestCloseNotDrainedChannel(t *testing.T) {
		t.Fatalf("failed to start server: %s", err)
	}

	clientAlice := NewClient(serverURL, hmacTokenStore, idAlice)
	clientAlice := NewClient(serverURL, hmacTokenStore, idAlice, iface.DefaultMTU)
	err = clientAlice.Connect(ctx)
	if err != nil {
		t.Fatalf("failed to connect to server: %s", err)
@@ -690,7 +691,7 @@ func TestCloseNotDrainedChannel(t *testing.T) {
		}
	}()

	clientBob := NewClient(serverURL, hmacTokenStore, idBob)
	clientBob := NewClient(serverURL, hmacTokenStore, idBob, iface.DefaultMTU)
	err = clientBob.Connect(ctx)
	if err != nil {
		t.Fatalf("failed to connect to server: %s", err)
@@ -12,7 +12,7 @@ import (
	log "github.com/sirupsen/logrus"

	quictls "github.com/netbirdio/netbird/shared/relay/tls"
	nbnet "github.com/netbirdio/netbird/util/net"
	nbnet "github.com/netbirdio/netbird/client/net"
)

type Dialer struct {
@@ -16,7 +16,7 @@ import (

	"github.com/netbirdio/netbird/shared/relay"
	"github.com/netbirdio/netbird/util/embeddedroots"
	nbnet "github.com/netbirdio/netbird/util/net"
	nbnet "github.com/netbirdio/netbird/client/net"
)

type Dialer struct {
@@ -63,20 +63,25 @@ type Manager struct {
	onDisconnectedListeners map[string]*list.List
	onReconnectedListenerFn func()
	listenerLock            sync.Mutex

	mtu uint16
}

// NewManager creates a new manager instance.
// The serverURL address can be empty. In this case, the manager will not serve.
func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager {
func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16) *Manager {
	tokenStore := &relayAuth.TokenStore{}

	m := &Manager{
		ctx:        ctx,
		peerID:     peerID,
		tokenStore: tokenStore,
		mtu:        mtu,
		serverPicker: &ServerPicker{
			TokenStore:        tokenStore,
			PeerID:            peerID,
			MTU:               mtu,
			ConnectionTimeout: defaultConnectionTimeout,
		},
		relayClients:            make(map[string]*RelayTrack),
		onDisconnectedListeners: make(map[string]*list.List),
@@ -253,7 +258,7 @@ func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string
	m.relayClients[serverAddress] = rt
	m.relayClientsMutex.Unlock()

	relayClient := NewClient(serverAddress, m.tokenStore, m.peerID)
	relayClient := NewClient(serverAddress, m.tokenStore, m.peerID, m.mtu)
	err := relayClient.Connect(m.ctx)
	if err != nil {
		rt.err = err
@@ -8,14 +8,15 @@ import (
	log "github.com/sirupsen/logrus"
	"go.opentelemetry.io/otel"

	"github.com/netbirdio/netbird/shared/relay/auth/allow"
	"github.com/netbirdio/netbird/client/iface"
	"github.com/netbirdio/netbird/relay/server"
	"github.com/netbirdio/netbird/shared/relay/auth/allow"
)

func TestEmptyURL(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	mgr := NewManager(ctx, nil, "alice")
	mgr := NewManager(ctx, nil, "alice", iface.DefaultMTU)
	err := mgr.Serve()
	if err == nil {
		t.Errorf("expected error, got nil")
@@ -90,12 +91,12 @@ func TestForeignConn(t *testing.T) {

	mCtx, cancel := context.WithCancel(ctx)
	defer cancel()
	clientAlice := NewManager(mCtx, toURL(lstCfg1), "alice")
	clientAlice := NewManager(mCtx, toURL(lstCfg1), "alice", iface.DefaultMTU)
	if err := clientAlice.Serve(); err != nil {
		t.Fatalf("failed to serve manager: %s", err)
	}

	clientBob := NewManager(mCtx, toURL(srvCfg2), "bob")
	clientBob := NewManager(mCtx, toURL(srvCfg2), "bob", iface.DefaultMTU)
	if err := clientBob.Serve(); err != nil {
		t.Fatalf("failed to serve manager: %s", err)
	}
@@ -197,12 +198,12 @@ func TestForeginConnClose(t *testing.T) {
	mCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	mgrBob := NewManager(mCtx, toURL(srvCfg2), "bob")
	mgrBob := NewManager(mCtx, toURL(srvCfg2), "bob", iface.DefaultMTU)
	if err := mgrBob.Serve(); err != nil {
		t.Fatalf("failed to serve manager: %s", err)
	}

	mgr := NewManager(mCtx, toURL(srvCfg1), "alice")
	mgr := NewManager(mCtx, toURL(srvCfg1), "alice", iface.DefaultMTU)
	err = mgr.Serve()
	if err != nil {
		t.Fatalf("failed to serve manager: %s", err)
@@ -282,7 +283,7 @@ func TestForeignAutoClose(t *testing.T) {
	t.Log("connect to server 1.")
	mCtx, cancel := context.WithCancel(ctx)
	defer cancel()
	mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
	mgr := NewManager(mCtx, toURL(srvCfg1), idAlice, iface.DefaultMTU)
	err = mgr.Serve()
	if err != nil {
		t.Fatalf("failed to serve manager: %s", err)
@@ -353,13 +354,13 @@ func TestAutoReconnect(t *testing.T) {
	mCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	clientBob := NewManager(mCtx, toURL(srvCfg), "bob")
	clientBob := NewManager(mCtx, toURL(srvCfg), "bob", iface.DefaultMTU)
	err = clientBob.Serve()
	if err != nil {
		t.Fatalf("failed to serve manager: %s", err)
	}

	clientAlice := NewManager(mCtx, toURL(srvCfg), "alice")
	clientAlice := NewManager(mCtx, toURL(srvCfg), "alice", iface.DefaultMTU)
	err = clientAlice.Serve()
	if err != nil {
		t.Fatalf("failed to serve manager: %s", err)
@@ -428,12 +429,12 @@ func TestNotifierDoubleAdd(t *testing.T) {
	mCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	clientBob := NewManager(mCtx, toURL(listenerCfg1), "bob")
	clientBob := NewManager(mCtx, toURL(listenerCfg1), "bob", iface.DefaultMTU)
	if err = clientBob.Serve(); err != nil {
		t.Fatalf("failed to serve manager: %s", err)
	}

	clientAlice := NewManager(mCtx, toURL(listenerCfg1), "alice")
	clientAlice := NewManager(mCtx, toURL(listenerCfg1), "alice", iface.DefaultMTU)
	if err = clientAlice.Serve(); err != nil {
		t.Fatalf("failed to serve manager: %s", err)
	}
@@ -13,11 +13,8 @@ import (
)

const (
	maxConcurrentServers = 7
)

var (
	connectionTimeout = 30 * time.Second
	maxConcurrentServers = 7
	defaultConnectionTimeout = 30 * time.Second
)

type connResult struct {
@@ -27,13 +24,15 @@ type connResult struct {
}

type ServerPicker struct {
	TokenStore        *auth.TokenStore
	ServerURLs        atomic.Value
	PeerID            string
	MTU               uint16
	ConnectionTimeout time.Duration
}

func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
	ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout)
	ctx, cancel := context.WithTimeout(parentCtx, sp.ConnectionTimeout)
	defer cancel()

	totalServers := len(sp.ServerURLs.Load().([]string))
@@ -70,7 +69,7 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {

func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
	log.Infof("try to connecting to relay server: %s", url)
	relayClient := NewClient(url, sp.TokenStore, sp.PeerID)
	relayClient := NewClient(url, sp.TokenStore, sp.PeerID, sp.MTU)
	err := relayClient.Connect(ctx)
	resultChan <- connResult{
		RelayClient: relayClient,
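Note: promoting the timeout from a package-level var to a struct field makes each picker independently configurable and removes the global that the old test mutated (see the test hunk below). A hedged in-package sketch; the token store, URLs, and MTU value are placeholders:

// pickRelay builds a picker with an explicit timeout and MTU.
func pickRelay(ctx context.Context, tokenStore *auth.TokenStore) (*Client, error) {
	sp := &ServerPicker{
		TokenStore:        tokenStore,
		PeerID:            "peer-id",
		MTU:               1280, // assumption: the client's default WireGuard MTU
		ConnectionTimeout: 30 * time.Second,
	}
	sp.ServerURLs.Store([]string{"rels://relay-1.example.com", "rels://relay-2.example.com"})
	return sp.PickServer(ctx)
}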
@@ -8,15 +8,15 @@ import (
)

func TestServerPicker_UnavailableServers(t *testing.T) {
	connectionTimeout = 5 * time.Second

	timeout := 5 * time.Second
	sp := ServerPicker{
		TokenStore:        nil,
		PeerID:            "test",
		ConnectionTimeout: timeout,
	}
	sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"})

	ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1)
	ctx, cancel := context.WithTimeout(context.Background(), timeout+1)
	defer cancel()

	go func() {
shared/relay/healthcheck/env.go (new file, 24 lines)
@@ -0,0 +1,24 @@
package healthcheck

import (
	"os"
	"strconv"

	log "github.com/sirupsen/logrus"
)

const (
	defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD"
)

func getAttemptThresholdFromEnv() int {
	if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" {
		threshold, err := strconv.ParseInt(attemptThreshold, 10, 64)
		if err != nil {
			log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold)
			return defaultAttemptThreshold
		}
		return int(threshold)
	}
	return defaultAttemptThreshold
}
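Note: the threshold sets how many missed heartbeats the relay healthcheck tolerates before declaring a timeout; defaultAttemptThreshold is 1 (see sender.go below). A standalone sketch of the same lookup this new file performs:

package main

import (
	"fmt"
	"os"
	"strconv"
)

func main() {
	// Operators can raise NB_RELAY_HC_ATTEMPT_THRESHOLD on flaky links.
	os.Setenv("NB_RELAY_HC_ATTEMPT_THRESHOLD", "3")
	threshold := 1 // package default
	if v := os.Getenv("NB_RELAY_HC_ATTEMPT_THRESHOLD"); v != "" {
		if n, err := strconv.ParseInt(v, 10, 64); err == nil {
			threshold = int(n)
		}
	}
	fmt.Println("attempt threshold:", threshold) // prints 3
}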
shared/relay/healthcheck/env_test.go (new file, 36 lines)
@@ -0,0 +1,36 @@
package healthcheck

import (
	"os"
	"testing"
)

//nolint:tenv
func TestGetAttemptThresholdFromEnv(t *testing.T) {
	tests := []struct {
		name     string
		envValue string
		expected int
	}{
		{"Default attempt threshold when env is not set", "", defaultAttemptThreshold},
		{"Custom attempt threshold when env is set to a valid integer", "3", 3},
		{"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if tt.envValue == "" {
				os.Unsetenv(defaultAttemptThresholdEnv)
			} else {
				os.Setenv(defaultAttemptThresholdEnv, tt.envValue)
			}

			result := getAttemptThresholdFromEnv()
			if result != tt.expected {
				t.Fatalf("Expected %d, got %d", tt.expected, result)
			}

			os.Unsetenv(defaultAttemptThresholdEnv)
		})
	}
}
@@ -7,10 +7,15 @@ import (
	log "github.com/sirupsen/logrus"
)

var (
	heartbeatTimeout = healthCheckInterval + 10*time.Second
const (
	defaultHeartbeatTimeout = defaultHealthCheckInterval + 10*time.Second
)

type ReceiverOptions struct {
	HeartbeatTimeout time.Duration
	AttemptThreshold int
}

// Receiver is a healthcheck receiver
// It will listen for heartbeat and check if the heartbeat is not received in a certain time
// If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work
@@ -27,6 +32,23 @@ type Receiver struct {

// NewReceiver creates a new healthcheck receiver and start the timer in the background
func NewReceiver(log *log.Entry) *Receiver {
	opts := ReceiverOptions{
		HeartbeatTimeout: defaultHeartbeatTimeout,
		AttemptThreshold: getAttemptThresholdFromEnv(),
	}
	return NewReceiverWithOpts(log, opts)
}

func NewReceiverWithOpts(log *log.Entry, opts ReceiverOptions) *Receiver {
	heartbeatTimeout := opts.HeartbeatTimeout
	if heartbeatTimeout <= 0 {
		heartbeatTimeout = defaultHeartbeatTimeout
	}
	attemptThreshold := opts.AttemptThreshold
	if attemptThreshold <= 0 {
		attemptThreshold = defaultAttemptThreshold
	}

	ctx, ctxCancel := context.WithCancel(context.Background())

	r := &Receiver{
@@ -35,10 +57,10 @@ func NewReceiver(log *log.Entry) *Receiver {
		ctx:       ctx,
		ctxCancel: ctxCancel,
		heartbeat: make(chan struct{}, 1),
		attemptThreshold: getAttemptThresholdFromEnv(),
		attemptThreshold: attemptThreshold,
	}

	go r.waitForHealthcheck()
	go r.waitForHealthcheck(heartbeatTimeout)
	return r
}
@@ -55,7 +77,7 @@ func (r *Receiver) Stop() {
	r.ctxCancel()
}

func (r *Receiver) waitForHealthcheck() {
func (r *Receiver) waitForHealthcheck(heartbeatTimeout time.Duration) {
	ticker := time.NewTicker(heartbeatTimeout)
	defer ticker.Stop()
	defer r.ctxCancel()
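Note: injecting the timeout and threshold through ReceiverOptions replaces the shared package variables the old tests had to guard with a mutex (removed below). A hedged in-package usage sketch:

// watchPeer sketches in-package usage; zero-valued fields fall back to
// the defaults enforced in NewReceiverWithOpts above.
func watchPeer(logEntry *log.Entry) *Receiver {
	r := NewReceiverWithOpts(logEntry, ReceiverOptions{
		HeartbeatTimeout: 10 * time.Second,
		AttemptThreshold: 3,
	})
	// Call r.Heartbeat() on every heartbeat from the peer to reset the
	// miss counter; call r.Stop() when the connection goes away.
	return r
}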
@@ -2,31 +2,18 @@ package healthcheck

import (
	"context"
	"fmt"
	"os"
	"sync"
	"testing"
	"time"

	log "github.com/sirupsen/logrus"
)

// Mutex to protect global variable access in tests
var testMutex sync.Mutex

func TestNewReceiver(t *testing.T) {
	testMutex.Lock()
	originalTimeout := heartbeatTimeout
	heartbeatTimeout = 5 * time.Second
	testMutex.Unlock()

	defer func() {
		testMutex.Lock()
		heartbeatTimeout = originalTimeout
		testMutex.Unlock()
	}()

	r := NewReceiver(log.WithContext(context.Background()))
	opts := ReceiverOptions{
		HeartbeatTimeout: 5 * time.Second,
	}
	r := NewReceiverWithOpts(log.WithContext(context.Background()), opts)
	defer r.Stop()

	select {
@@ -38,18 +25,10 @@ func TestNewReceiver(t *testing.T) {
}

func TestNewReceiverNotReceive(t *testing.T) {
	testMutex.Lock()
	originalTimeout := heartbeatTimeout
	heartbeatTimeout = 1 * time.Second
	testMutex.Unlock()

	defer func() {
		testMutex.Lock()
		heartbeatTimeout = originalTimeout
		testMutex.Unlock()
	}()

	r := NewReceiver(log.WithContext(context.Background()))
	opts := ReceiverOptions{
		HeartbeatTimeout: 1 * time.Second,
	}
	r := NewReceiverWithOpts(log.WithContext(context.Background()), opts)
	defer r.Stop()

	select {
@@ -61,18 +40,10 @@ func TestNewReceiverNotReceive(t *testing.T) {
}

func TestNewReceiverAck(t *testing.T) {
	testMutex.Lock()
	originalTimeout := heartbeatTimeout
	heartbeatTimeout = 2 * time.Second
	testMutex.Unlock()

	defer func() {
		testMutex.Lock()
		heartbeatTimeout = originalTimeout
		testMutex.Unlock()
	}()

	r := NewReceiver(log.WithContext(context.Background()))
	opts := ReceiverOptions{
		HeartbeatTimeout: 2 * time.Second,
	}
	r := NewReceiverWithOpts(log.WithContext(context.Background()), opts)
	defer r.Stop()

	r.Heartbeat()
@@ -97,30 +68,19 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {

	for _, tc := range testsCases {
		t.Run(tc.name, func(t *testing.T) {
			testMutex.Lock()
			originalInterval := healthCheckInterval
			originalTimeout := heartbeatTimeout
			healthCheckInterval = 1 * time.Second
			heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
			testMutex.Unlock()
			healthCheckInterval := 1 * time.Second

			defer func() {
				testMutex.Lock()
				healthCheckInterval = originalInterval
				heartbeatTimeout = originalTimeout
				testMutex.Unlock()
			}()
			//nolint:tenv
			os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
			defer os.Unsetenv(defaultAttemptThresholdEnv)
			opts := ReceiverOptions{
				HeartbeatTimeout: healthCheckInterval + 500*time.Millisecond,
				AttemptThreshold: tc.threshold,
			}

			receiver := NewReceiver(log.WithField("test_name", tc.name))
			receiver := NewReceiverWithOpts(log.WithField("test_name", tc.name), opts)

			testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval
			testTimeout := opts.HeartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval

			if tc.resetCounterOnce {
				receiver.Heartbeat()
				t.Logf("reset counter once")
			}

			select {
@@ -134,7 +94,6 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
			}
			t.Fatalf("should have timed out before %s", testTimeout)
		}

		})
	}
}
@@ -2,52 +2,76 @@ package healthcheck

import (
	"context"
	"os"
	"strconv"
	"time"

	log "github.com/sirupsen/logrus"
)

const (
	defaultAttemptThreshold = 1
	defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD"
	defaultAttemptThreshold    = 1

	defaultHealthCheckInterval = 25 * time.Second
	defaultHealthCheckTimeout  = 20 * time.Second
)

var (
	healthCheckInterval = 25 * time.Second
	healthCheckTimeout  = 20 * time.Second
)
type SenderOptions struct {
	HealthCheckInterval time.Duration
	HealthCheckTimeout  time.Duration
	AttemptThreshold    int
}

// Sender is a healthcheck sender
// It will send healthcheck signal to the receiver
// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work
// It will also stop if the context is canceled
type Sender struct {
	log *log.Entry
	// HealthCheck is a channel to send health check signal to the peer
	HealthCheck chan struct{}
	// Timeout is a channel to the health check signal is not received in a certain time
	Timeout chan struct{}

	log                 *log.Entry
	healthCheckInterval time.Duration
	timeout             time.Duration

	ack              chan struct{}
	alive            bool
	attemptThreshold int
}

// NewSender creates a new healthcheck sender
func NewSender(log *log.Entry) *Sender {
func NewSenderWithOpts(log *log.Entry, opts SenderOptions) *Sender {
	if opts.HealthCheckInterval <= 0 {
		opts.HealthCheckInterval = defaultHealthCheckInterval
	}
	if opts.HealthCheckTimeout <= 0 {
		opts.HealthCheckTimeout = defaultHealthCheckTimeout
	}
	if opts.AttemptThreshold <= 0 {
		opts.AttemptThreshold = defaultAttemptThreshold
	}
	hc := &Sender{
		log: log,
		HealthCheck: make(chan struct{}, 1),
		Timeout: make(chan struct{}, 1),
		ack: make(chan struct{}, 1),
		attemptThreshold: getAttemptThresholdFromEnv(),
		HealthCheck:         make(chan struct{}, 1),
		Timeout:             make(chan struct{}, 1),
		log:                 log,
		healthCheckInterval: opts.HealthCheckInterval,
		timeout:             opts.HealthCheckInterval + opts.HealthCheckTimeout,
		ack:                 make(chan struct{}, 1),
		attemptThreshold:    opts.AttemptThreshold,
	}

	return hc
}

// NewSender creates a new healthcheck sender
func NewSender(log *log.Entry) *Sender {
	opts := SenderOptions{
		HealthCheckInterval: defaultHealthCheckInterval,
		HealthCheckTimeout:  defaultHealthCheckTimeout,
		AttemptThreshold:    getAttemptThresholdFromEnv(),
	}
	return NewSenderWithOpts(log, opts)
}

// OnHCResponse sends an acknowledgment signal to the sender
func (hc *Sender) OnHCResponse() {
	select {
@@ -57,10 +81,10 @@ func (hc *Sender) OnHCResponse() {
}

func (hc *Sender) StartHealthCheck(ctx context.Context) {
	ticker := time.NewTicker(healthCheckInterval)
	ticker := time.NewTicker(hc.healthCheckInterval)
	defer ticker.Stop()

	timeoutTicker := time.NewTicker(hc.getTimeoutTime())
	timeoutTicker := time.NewTicker(hc.timeout)
	defer timeoutTicker.Stop()

	defer close(hc.HealthCheck)
@@ -92,19 +116,3 @@ func (hc *Sender) StartHealthCheck(ctx context.Context) {
	}
}
}

func (hc *Sender) getTimeoutTime() time.Duration {
	return healthCheckInterval + healthCheckTimeout
}

func getAttemptThresholdFromEnv() int {
	if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" {
		threshold, err := strconv.ParseInt(attemptThreshold, 10, 64)
		if err != nil {
			log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold)
			return defaultAttemptThreshold
		}
		return int(threshold)
	}
	return defaultAttemptThreshold
}
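Note: the Sender gets the same options treatment as the Receiver: interval, timeout, and threshold become injectable, and getTimeoutTime/getAttemptThresholdFromEnv move out (the latter into the new env.go). A hedged in-package sketch of the consumer side; the wire write is a placeholder:

// driveHealthcheck forwards each HealthCheck tick to the peer and treats
// Timeout as a dead connection; call hc.OnHCResponse() from the read loop
// when the peer's answer arrives.
func driveHealthcheck(ctx context.Context, logEntry *log.Entry) {
	hc := NewSenderWithOpts(logEntry, SenderOptions{
		HealthCheckInterval: 5 * time.Second,
		HealthCheckTimeout:  2 * time.Second,
		AttemptThreshold:    2,
	})
	go hc.StartHealthCheck(ctx)
	for {
		select {
		case <-hc.HealthCheck:
			// placeholder: write the healthcheck message to the wire
		case <-hc.Timeout:
			return // unacknowledged healthchecks exceeded the threshold
		case <-ctx.Done():
			return
		}
	}
}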
@@ -2,26 +2,23 @@ package healthcheck

import (
	"context"
	"fmt"
	"os"
	"testing"
	"time"

	log "github.com/sirupsen/logrus"
)

func TestMain(m *testing.M) {
	// override the health check interval to speed up the test
	healthCheckInterval = 2 * time.Second
	healthCheckTimeout = 100 * time.Millisecond
	code := m.Run()
	os.Exit(code)
}
var (
	testOpts = SenderOptions{
		HealthCheckInterval: 2 * time.Second,
		HealthCheckTimeout:  100 * time.Millisecond,
	}
)

func TestNewHealthPeriod(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	hc := NewSender(log.WithContext(ctx))
	hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
	go hc.StartHealthCheck(ctx)

	iterations := 0
@@ -32,7 +29,7 @@ func TestNewHealthPeriod(t *testing.T) {
			hc.OnHCResponse()
		case <-hc.Timeout:
			t.Fatalf("health check is timed out")
		case <-time.After(healthCheckInterval + 100*time.Millisecond):
		case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond):
			t.Fatalf("health check not received")
		}
	}
@@ -41,19 +38,19 @@ func TestNewHealthPeriod(t *testing.T) {
func TestNewHealthFailed(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	hc := NewSender(log.WithContext(ctx))
	hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
	go hc.StartHealthCheck(ctx)

	select {
	case <-hc.Timeout:
	case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond):
	case <-time.After(testOpts.HealthCheckInterval + testOpts.HealthCheckTimeout + 100*time.Millisecond):
		t.Fatalf("health check is not timed out")
	}
}

func TestNewHealthcheckStop(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	hc := NewSender(log.WithContext(ctx))
	hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
	go hc.StartHealthCheck(ctx)

	time.Sleep(100 * time.Millisecond)
@@ -78,7 +75,7 @@ func TestNewHealthcheckStop(t *testing.T) {
func TestTimeoutReset(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	hc := NewSender(log.WithContext(ctx))
	hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
	go hc.StartHealthCheck(ctx)

	iterations := 0
@@ -89,7 +86,7 @@ func TestTimeoutReset(t *testing.T) {
			hc.OnHCResponse()
		case <-hc.Timeout:
			t.Fatalf("health check is timed out")
		case <-time.After(healthCheckInterval + 100*time.Millisecond):
		case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond):
			t.Fatalf("health check not received")
		}
	}
@@ -118,19 +115,16 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {

	for _, tc := range testsCases {
		t.Run(tc.name, func(t *testing.T) {
			originalInterval := healthCheckInterval
			originalTimeout := healthCheckTimeout
			healthCheckInterval = 1 * time.Second
			healthCheckTimeout = 500 * time.Millisecond

			//nolint:tenv
			os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
			defer os.Unsetenv(defaultAttemptThresholdEnv)
			opts := SenderOptions{
				HealthCheckInterval: 1 * time.Second,
				HealthCheckTimeout:  500 * time.Millisecond,
				AttemptThreshold:    tc.threshold,
			}

			ctx, cancel := context.WithCancel(context.Background())
			defer cancel()

			sender := NewSender(log.WithField("test_name", tc.name))
			sender := NewSenderWithOpts(log.WithField("test_name", tc.name), opts)
			senderExit := make(chan struct{})
			go func() {
				sender.StartHealthCheck(ctx)
@@ -155,7 +149,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
				}
			}()

			testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval
			testTimeout := (opts.HealthCheckInterval+opts.HealthCheckTimeout)*time.Duration(tc.threshold) + opts.HealthCheckInterval

			select {
			case <-sender.Timeout:
@@ -175,39 +169,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
			case <-time.After(2 * time.Second):
				t.Fatalf("sender did not exit in time")
			}
			healthCheckInterval = originalInterval
			healthCheckTimeout = originalTimeout
		})
	}

}

//nolint:tenv
func TestGetAttemptThresholdFromEnv(t *testing.T) {
	tests := []struct {
		name     string
		envValue string
		expected int
	}{
		{"Default attempt threshold when env is not set", "", defaultAttemptThreshold},
		{"Custom attempt threshold when env is set to a valid integer", "3", 3},
		{"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if tt.envValue == "" {
				os.Unsetenv(defaultAttemptThresholdEnv)
			} else {
				os.Setenv(defaultAttemptThresholdEnv, tt.envValue)
			}

			result := getAttemptThresholdFromEnv()
			if result != tt.expected {
				t.Fatalf("Expected %d, got %d", tt.expected, result)
			}

			os.Unsetenv(defaultAttemptThresholdEnv)
		})
	}
}
@@ -16,10 +16,10 @@ import (
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"

	nbgrpc "github.com/netbirdio/netbird/client/grpc"
	"github.com/netbirdio/netbird/encryption"
	"github.com/netbirdio/netbird/shared/management/client"
	"github.com/netbirdio/netbird/shared/signal/proto"
	nbgrpc "github.com/netbirdio/netbird/util/grpc"
)

// ConnStateNotifier is a wrapper interface of the status recorder
@@ -57,7 +57,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo

	operation := func() error {
		var err error
		conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
		conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled)
		if err != nil {
			log.Printf("createConnection error: %v", err)
			return err