mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-06 00:56:39 +00:00
Compare commits
4 Commits
wasm-webso
...
feature/lo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
37b9905b68 | ||
|
|
92e53d6319 | ||
|
|
8a7d78ddf3 | ||
|
|
ea83cbf917 |
@@ -17,7 +17,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
||||
nbwebsocket "github.com/netbirdio/netbird/client/wasm/internal/websocket"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -27,7 +26,6 @@ const (
|
||||
pingTimeout = 10 * time.Second
|
||||
defaultLogLevel = "warn"
|
||||
defaultSSHDetectionTimeout = 20 * time.Second
|
||||
dialWebSocketTimeout = 30 * time.Second
|
||||
|
||||
icmpEchoRequest = 8
|
||||
icmpCodeEcho = 0
|
||||
@@ -518,7 +516,6 @@ func createClientObject(client *netbird.Client) js.Value {
|
||||
obj["createSSHConnection"] = createSSHMethod(client)
|
||||
obj["proxyRequest"] = createProxyRequestMethod(client)
|
||||
obj["createRDPProxy"] = createRDPProxyMethod(client)
|
||||
obj["dialWebSocket"] = createDialWebSocketMethod(client)
|
||||
obj["status"] = createStatusMethod(client)
|
||||
obj["statusSummary"] = createStatusSummaryMethod(client)
|
||||
obj["statusDetail"] = createStatusDetailMethod(client)
|
||||
@@ -528,74 +525,6 @@ func createClientObject(client *netbird.Client) js.Value {
|
||||
return js.ValueOf(obj)
|
||||
}
|
||||
|
||||
func createDialWebSocketMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
url, protocols, timeout, errVal := parseDialWebSocketArgs(args)
|
||||
if !errVal.IsUndefined() {
|
||||
return errVal
|
||||
}
|
||||
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := nbwebsocket.Dial(ctx, client, url, protocols)
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(fmt.Sprintf("dial websocket: %v", err)))
|
||||
return
|
||||
}
|
||||
|
||||
resolve.Invoke(nbwebsocket.NewJSInterface(conn))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func parseDialWebSocketArgs(args []js.Value) (url string, protocols []string, timeout time.Duration, errVal js.Value) {
|
||||
if len(args) < 1 || args[0].Type() != js.TypeString {
|
||||
return "", nil, 0, js.ValueOf("error: dialWebSocket requires a URL string argument")
|
||||
}
|
||||
url = args[0].String()
|
||||
|
||||
if len(args) >= 2 && !args[1].IsNull() && !args[1].IsUndefined() {
|
||||
arr, err := jsStringArray(args[1])
|
||||
if err != nil {
|
||||
return "", nil, 0, js.ValueOf(fmt.Sprintf("error: protocols: %v", err))
|
||||
}
|
||||
protocols = arr
|
||||
}
|
||||
|
||||
timeout = dialWebSocketTimeout
|
||||
if len(args) >= 3 && !args[2].IsNull() && !args[2].IsUndefined() {
|
||||
if args[2].Type() != js.TypeNumber {
|
||||
return "", nil, 0, js.ValueOf("error: timeoutMs must be a number")
|
||||
}
|
||||
timeoutMs := args[2].Int()
|
||||
if timeoutMs <= 0 {
|
||||
return "", nil, 0, js.ValueOf("error: timeout must be positive")
|
||||
}
|
||||
timeout = time.Duration(timeoutMs) * time.Millisecond
|
||||
}
|
||||
|
||||
return url, protocols, timeout, js.Undefined()
|
||||
}
|
||||
|
||||
// jsStringArray converts a JS array of strings to a Go []string.
|
||||
func jsStringArray(v js.Value) ([]string, error) {
|
||||
if !v.InstanceOf(js.Global().Get("Array")) {
|
||||
return nil, fmt.Errorf("expected array")
|
||||
}
|
||||
n := v.Length()
|
||||
out := make([]string, n)
|
||||
for i := 0; i < n; i++ {
|
||||
el := v.Index(i)
|
||||
if el.Type() != js.TypeString {
|
||||
return nil, fmt.Errorf("element %d is not a string", i)
|
||||
}
|
||||
out[i] = el.String()
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// netBirdClientConstructor acts as a JavaScript constructor function
|
||||
func netBirdClientConstructor(_ js.Value, args []js.Value) any {
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
||||
|
||||
@@ -1,304 +0,0 @@
|
||||
//go:build js
|
||||
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall/js"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
netbird "github.com/netbirdio/netbird/client/embed"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type closeError struct {
|
||||
code uint16
|
||||
reason string
|
||||
}
|
||||
|
||||
func (e *closeError) Error() string {
|
||||
return fmt.Sprintf("websocket closed: %d %s", e.code, e.reason)
|
||||
}
|
||||
|
||||
// bufferedConn fronts a net.Conn with a reader that serves any bytes buffered
|
||||
// during the WebSocket handshake before falling through to the raw conn.
|
||||
type bufferedConn struct {
|
||||
net.Conn
|
||||
r io.Reader
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Read(p []byte) (int, error) { return c.r.Read(p) }
|
||||
|
||||
// Conn wraps a WebSocket connection over a NetBird TCP connection.
|
||||
type Conn struct {
|
||||
conn net.Conn
|
||||
mu sync.Mutex
|
||||
closed chan struct{}
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
}
|
||||
|
||||
// Dial establishes a WebSocket connection to the given URL through the NetBird network.
|
||||
// Optional protocols are sent via the Sec-WebSocket-Protocol header.
|
||||
func Dial(ctx context.Context, client *netbird.Client, rawURL string, protocols []string) (*Conn, error) {
|
||||
d := ws.Dialer{
|
||||
NetDial: client.Dial,
|
||||
Protocols: protocols,
|
||||
}
|
||||
|
||||
conn, br, _, err := d.Dial(ctx, rawURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("websocket dial: %w", err)
|
||||
}
|
||||
|
||||
// br is non-nil when the server pushed frames alongside the handshake
|
||||
// response; those bytes live in the bufio.Reader and must be drained
|
||||
// before reading from conn, otherwise we'd skip the first frames.
|
||||
if br != nil {
|
||||
if br.Buffered() > 0 {
|
||||
conn = &bufferedConn{Conn: conn, r: io.MultiReader(br, conn)}
|
||||
} else {
|
||||
ws.PutReader(br)
|
||||
}
|
||||
}
|
||||
|
||||
return &Conn{
|
||||
conn: conn,
|
||||
closed: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ReadMessage reads the next WebSocket message, handling control frames automatically.
|
||||
func (c *Conn) ReadMessage() (ws.OpCode, []byte, error) {
|
||||
for {
|
||||
msgs, err := wsutil.ReadServerMessage(c.conn, nil)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
for _, msg := range msgs {
|
||||
if msg.OpCode.IsControl() {
|
||||
if err := c.handleControl(msg); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
return msg.OpCode, msg.Payload, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) handleControl(msg wsutil.Message) error {
|
||||
switch msg.OpCode {
|
||||
case ws.OpPing:
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return wsutil.WriteClientMessage(c.conn, ws.OpPong, msg.Payload)
|
||||
case ws.OpClose:
|
||||
code, reason := parseClosePayload(msg.Payload)
|
||||
return &closeError{code: code, reason: reason}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WriteText sends a text WebSocket message.
|
||||
func (c *Conn) WriteText(data []byte) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return wsutil.WriteClientMessage(c.conn, ws.OpText, data)
|
||||
}
|
||||
|
||||
// WriteBinary sends a binary WebSocket message.
|
||||
func (c *Conn) WriteBinary(data []byte) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return wsutil.WriteClientMessage(c.conn, ws.OpBinary, data)
|
||||
}
|
||||
|
||||
// Close sends a close frame with StatusNormalClosure and closes the underlying connection.
|
||||
func (c *Conn) Close() error {
|
||||
return c.closeWith(ws.StatusNormalClosure, "")
|
||||
}
|
||||
|
||||
// closeWith sends a close frame with the given code/reason and closes the underlying connection.
|
||||
// Used to echo the server's code when responding to a server-initiated close per RFC 6455 §5.5.1.
|
||||
func (c *Conn) closeWith(code ws.StatusCode, reason string) error {
|
||||
var first bool
|
||||
c.closeOnce.Do(func() {
|
||||
first = true
|
||||
close(c.closed)
|
||||
|
||||
c.mu.Lock()
|
||||
_ = wsutil.WriteClientMessage(c.conn, ws.OpClose, ws.NewCloseFrameBody(code, reason))
|
||||
c.mu.Unlock()
|
||||
|
||||
c.closeErr = c.conn.Close()
|
||||
})
|
||||
|
||||
if !first {
|
||||
return net.ErrClosed
|
||||
}
|
||||
return c.closeErr
|
||||
}
|
||||
|
||||
// NewJSInterface creates a JavaScript object wrapping the WebSocket connection.
|
||||
// It exposes: send(string|Uint8Array), close(), and callback properties
|
||||
// onmessage, onclose, onerror.
|
||||
//
|
||||
// Callback properties may be set from the JS thread while the read loop
|
||||
// goroutine reads them. In WASM this is safe because Go and JS share a
|
||||
// single thread, but the design would need synchronization on
|
||||
// multi-threaded runtimes.
|
||||
func NewJSInterface(conn *Conn) js.Value {
|
||||
obj := js.Global().Get("Object").Call("create", js.Null())
|
||||
|
||||
sendFunc := js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
log.Errorf("websocket send requires a data argument")
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
|
||||
data := args[0]
|
||||
switch data.Type() {
|
||||
case js.TypeString:
|
||||
if err := conn.WriteText([]byte(data.String())); err != nil {
|
||||
log.Errorf("failed to send websocket text: %v", err)
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
default:
|
||||
buf, err := jsToBytes(data)
|
||||
if err != nil {
|
||||
log.Errorf("failed to convert js value to bytes: %v", err)
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
if err := conn.WriteBinary(buf); err != nil {
|
||||
log.Errorf("failed to send websocket binary: %v", err)
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
}
|
||||
return js.ValueOf(true)
|
||||
})
|
||||
obj.Set("send", sendFunc)
|
||||
|
||||
closeFunc := js.FuncOf(func(_ js.Value, _ []js.Value) any {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("failed to close websocket: %v", err)
|
||||
}
|
||||
return js.Undefined()
|
||||
})
|
||||
obj.Set("close", closeFunc)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
log.Debugf("close websocket on readLoop exit: %v", err)
|
||||
}
|
||||
}()
|
||||
readLoop(conn, obj)
|
||||
// Undefining before Release turns post-close JS calls into TypeError
|
||||
// instead of a silent "call to released function".
|
||||
obj.Set("send", js.Undefined())
|
||||
obj.Set("close", js.Undefined())
|
||||
sendFunc.Release()
|
||||
closeFunc.Release()
|
||||
}()
|
||||
|
||||
return obj
|
||||
}
|
||||
|
||||
func jsToBytes(data js.Value) ([]byte, error) {
|
||||
var uint8Array js.Value
|
||||
switch {
|
||||
case data.InstanceOf(js.Global().Get("Uint8Array")):
|
||||
uint8Array = data
|
||||
case data.InstanceOf(js.Global().Get("ArrayBuffer")):
|
||||
uint8Array = js.Global().Get("Uint8Array").New(data)
|
||||
default:
|
||||
return nil, fmt.Errorf("send: unsupported data type, use string, Uint8Array, or ArrayBuffer")
|
||||
}
|
||||
|
||||
buf := make([]byte, uint8Array.Get("length").Int())
|
||||
js.CopyBytesToGo(buf, uint8Array)
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func readLoop(conn *Conn, obj js.Value) {
|
||||
var ce *closeError
|
||||
defer func() { invokeOnClose(obj, ce) }()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-conn.closed:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
op, payload, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
ce = handleReadError(conn, obj, err)
|
||||
return
|
||||
}
|
||||
|
||||
dispatchMessage(obj, op, payload)
|
||||
}
|
||||
}
|
||||
|
||||
func handleReadError(conn *Conn, obj js.Value, err error) *closeError {
|
||||
var ce *closeError
|
||||
if errors.As(err, &ce) {
|
||||
if cerr := conn.closeWith(ws.StatusCode(ce.code), ce.reason); cerr != nil {
|
||||
log.Debugf("failed to close websocket after server close frame: %v", cerr)
|
||||
}
|
||||
return ce
|
||||
}
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
if onerror := obj.Get("onerror"); onerror.Truthy() {
|
||||
onerror.Invoke(js.ValueOf(err.Error()))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func invokeOnClose(obj js.Value, ce *closeError) {
|
||||
onclose := obj.Get("onclose")
|
||||
if !onclose.Truthy() {
|
||||
return
|
||||
}
|
||||
if ce != nil {
|
||||
onclose.Invoke(js.ValueOf(int(ce.code)), js.ValueOf(ce.reason))
|
||||
return
|
||||
}
|
||||
onclose.Invoke()
|
||||
}
|
||||
|
||||
func dispatchMessage(obj js.Value, op ws.OpCode, payload []byte) {
|
||||
onmessage := obj.Get("onmessage")
|
||||
if !onmessage.Truthy() {
|
||||
return
|
||||
}
|
||||
switch op {
|
||||
case ws.OpText:
|
||||
onmessage.Invoke(js.ValueOf(string(payload)))
|
||||
case ws.OpBinary:
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(payload))
|
||||
js.CopyBytesToJS(uint8Array, payload)
|
||||
onmessage.Invoke(uint8Array)
|
||||
}
|
||||
}
|
||||
|
||||
func parseClosePayload(payload []byte) (uint16, string) {
|
||||
if len(payload) < 2 {
|
||||
return 1005, "" // RFC 6455: No Status Rcvd
|
||||
}
|
||||
code := binary.BigEndian.Uint16(payload[:2])
|
||||
return code, string(payload[2:])
|
||||
}
|
||||
3
go.mod
3
go.mod
@@ -52,7 +52,6 @@ require (
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/gliderlabs/ssh v0.3.8
|
||||
github.com/go-jose/go-jose/v4 v4.1.3
|
||||
github.com/gobwas/ws v1.4.0
|
||||
github.com/godbus/dbus/v5 v5.1.0
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||
github.com/golang/mock v1.6.0
|
||||
@@ -205,8 +204,6 @@ require (
|
||||
github.com/go-sql-driver/mysql v1.9.3 // indirect
|
||||
github.com/go-text/render v0.2.0 // indirect
|
||||
github.com/go-text/typesetting v0.2.1 // indirect
|
||||
github.com/gobwas/httphead v0.1.0 // indirect
|
||||
github.com/gobwas/pool v0.2.1 // indirect
|
||||
github.com/goccy/go-yaml v1.18.0 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
|
||||
|
||||
7
go.sum
7
go.sum
@@ -233,12 +233,6 @@ github.com/go-text/typesetting v0.2.1 h1:x0jMOGyO3d1qFAPI0j4GSsh7M0Q3Ypjzr4+CEVg
|
||||
github.com/go-text/typesetting v0.2.1/go.mod h1:mTOxEwasOFpAMBjEQDhdWRckoLLeI/+qrQeBCTGEt6M=
|
||||
github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066 h1:qCuYC+94v2xrb1PoS4NIDe7DGYtLnU2wWiQe9a1B1c0=
|
||||
github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066/go.mod h1:DDxDdQEnB70R8owOx3LVpEFvpMK9eeH1o2r0yZhFI9o=
|
||||
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
|
||||
github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
|
||||
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
|
||||
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
|
||||
github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs=
|
||||
github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc=
|
||||
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
|
||||
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
|
||||
@@ -793,7 +787,6 @@ golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
cachestore "github.com/eko/gocache/lib/v4/store"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
@@ -34,6 +35,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/settingoverrider"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
@@ -73,6 +75,23 @@ func (s *BaseServer) CacheStore() cachestore.StoreInterface {
|
||||
})
|
||||
}
|
||||
|
||||
// SettingOverrider returns a shared setting overrider backed by Redis.
|
||||
// Returns a no-op overrider if no Redis address is configured.
|
||||
func (s *BaseServer) SettingOverrider() *settingoverrider.Overrider {
|
||||
return Create(s, func() *settingoverrider.Overrider {
|
||||
redisAddr := nbcache.GetAddrFromEnv()
|
||||
if redisAddr == "" {
|
||||
return settingoverrider.NewNoop()
|
||||
}
|
||||
|
||||
o, err := settingoverrider.New(context.Background(), redisAddr)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create setting overrider: %v", err)
|
||||
}
|
||||
return o
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) Store() store.Store {
|
||||
return Create(s, func() store.Store {
|
||||
store, err := store.NewStore(context.Background(), s.Config.StoreConfig.Engine, s.Config.Datadir, s.Metrics(), false)
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/metrics"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/settingoverrider"
|
||||
"github.com/netbirdio/netbird/util/wsproxy"
|
||||
wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
@@ -123,6 +124,15 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
s.PeersManager()
|
||||
s.GeoLocationManager()
|
||||
|
||||
s.SettingOverrider().Poll(settingoverrider.DefaultInterval, "managementLogLevel", func(value string) error {
|
||||
level, err := log.ParseLevel(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing log level %q: %w", value, err)
|
||||
}
|
||||
log.SetLevel(level)
|
||||
return nil
|
||||
})
|
||||
|
||||
err := s.Metrics().Expose(srvCtx, s.mgmtMetricsPort, "/metrics")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to expose metrics: %v", err)
|
||||
@@ -235,6 +245,7 @@ func (s *BaseServer) Stop() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_ = s.SettingOverrider().Close()
|
||||
s.IntegratedValidator().Stop(ctx)
|
||||
if s.GeoLocationManager() != nil {
|
||||
_ = s.GeoLocationManager().Stop()
|
||||
|
||||
@@ -2311,29 +2311,6 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetExpiredPeers_SkipsAlreadyExpired(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
testStore, cleanUp, err := store.NewTestStoreFromSQL(ctx, "testdata/store_with_expired_peers.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
// Verify the already-expired peer is excluded at the store level
|
||||
peers, err := testStore.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, peer := range peers {
|
||||
assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should be excluded by the store query")
|
||||
assert.False(t, peer.Status.LoginExpired, "returned peers should not already be marked as login expired")
|
||||
}
|
||||
|
||||
// Only the non-expired peer with expiration enabled should be returned
|
||||
require.Len(t, peers, 1)
|
||||
assert.Equal(t, "notexpired01", peers[0].ID)
|
||||
}
|
||||
|
||||
func TestAccount_GetInactivePeers(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
@@ -3253,13 +3230,6 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
|
||||
return manager, updateManager, account, peer1, peer2, peer3
|
||||
}
|
||||
|
||||
// peerUpdateTimeout bounds how long peerShouldReceiveUpdate and its outer
|
||||
// wrappers wait for an expected update message. Sized for slow CI runners
|
||||
// (MySQL, FreeBSD, loaded sqlite) where the channel publish can take
|
||||
// seconds. Only runs down on failure; passing tests return immediately
|
||||
// when the channel delivers.
|
||||
const peerUpdateTimeout = 5 * time.Second
|
||||
|
||||
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
|
||||
t.Helper()
|
||||
select {
|
||||
@@ -3278,7 +3248,7 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.Upd
|
||||
if msg == nil {
|
||||
t.Errorf("Received nil update message, expected valid message")
|
||||
}
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Error("Timed out waiting for update message")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -458,7 +458,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -478,7 +478,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -518,7 +518,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -620,7 +620,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -638,7 +638,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -656,7 +656,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -689,7 +689,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -730,7 +730,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -757,7 +757,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -804,7 +804,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -267,8 +267,8 @@ func Test_SyncProtocol(t *testing.T) {
|
||||
}
|
||||
|
||||
// expired peers come separately.
|
||||
if len(networkMap.GetOfflinePeers()) != 2 {
|
||||
t.Fatal("expecting SyncResponse to have NetworkMap with 2 offline peer")
|
||||
if len(networkMap.GetOfflinePeers()) != 1 {
|
||||
t.Fatal("expecting SyncResponse to have NetworkMap with 1 offline peer")
|
||||
}
|
||||
|
||||
expiredPeerPubKey := "RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4="
|
||||
|
||||
@@ -1087,7 +1087,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1105,7 +1105,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1405,10 +1405,6 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID
|
||||
|
||||
var peers []*nbpeer.Peer
|
||||
for _, peer := range peersWithExpiry {
|
||||
if peer.Status.LoginExpired {
|
||||
continue
|
||||
}
|
||||
|
||||
expired, _ := peer.LoginExpired(settings.PeerLoginExpiration)
|
||||
if expired {
|
||||
peers = append(peers, peer)
|
||||
|
||||
@@ -1907,7 +1907,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1929,7 +1929,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1994,7 +1994,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2012,7 +2012,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2058,7 +2058,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2076,7 +2076,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2113,7 +2113,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2131,7 +2131,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1231,7 +1231,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1263,7 +1263,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1294,7 +1294,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1314,7 +1314,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1355,7 +1355,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1373,7 +1373,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
|
||||
@@ -1393,7 +1393,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -244,7 +244,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -273,7 +273,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -292,7 +292,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -395,7 +395,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -438,7 +438,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -2070,7 +2070,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
|
||||
@@ -2107,7 +2107,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2127,7 +2127,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2145,7 +2145,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2185,7 +2185,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2225,7 +2225,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -3310,7 +3310,7 @@ func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStreng
|
||||
|
||||
var peers []*nbpeer.Peer
|
||||
result := tx.
|
||||
Where("login_expiration_enabled = ? AND peer_status_login_expired != ? AND user_id IS NOT NULL AND user_id != ''", true, true).
|
||||
Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
|
||||
Find(&peers, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error)
|
||||
|
||||
@@ -2729,7 +2729,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) {
|
||||
{
|
||||
name: "should retrieve peers for an existing account ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectedCount: 5,
|
||||
expectedCount: 4,
|
||||
},
|
||||
{
|
||||
name: "should return no peers for a non-existing account ID",
|
||||
@@ -2751,7 +2751,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) {
|
||||
name: "should filter peers by partial name",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
nameFilter: "host",
|
||||
expectedCount: 4,
|
||||
expectedCount: 3,
|
||||
},
|
||||
{
|
||||
name: "should filter peers by ip",
|
||||
@@ -2777,16 +2777,14 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
expectedCount int
|
||||
expectedPeerIDs []string
|
||||
name string
|
||||
accountID string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "should retrieve only non-expired peers with expiration enabled",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectedCount: 1,
|
||||
expectedPeerIDs: []string{"notexpired01"},
|
||||
name: "should retrieve peers with expiration for an existing account ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectedCount: 1,
|
||||
},
|
||||
{
|
||||
name: "should return no peers with expiration for a non-existing account ID",
|
||||
@@ -2805,30 +2803,10 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) {
|
||||
peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, tt.accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peers, tt.expectedCount)
|
||||
for i, peer := range peers {
|
||||
assert.Equal(t, tt.expectedPeerIDs[i], peer.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountPeersWithExpiration_ExcludesAlreadyExpired(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the already-expired peer (cg05lnblo1hkg2j514p0) is not returned
|
||||
for _, peer := range peers {
|
||||
assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should not be returned")
|
||||
assert.False(t, peer.Status.LoginExpired, "returned peers should not have LoginExpired set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
@@ -2909,7 +2887,7 @@ func TestSqlStore_GetUserPeers(t *testing.T) {
|
||||
name: "should retrieve peers for another valid account ID and user ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
userID: "edafee4e-63fb-11ec-90d6-0242ac120003",
|
||||
expectedCount: 3,
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "should return no peers for existing account ID with empty user ID",
|
||||
|
||||
@@ -31,7 +31,6 @@ INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-3465300
|
||||
INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','nVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HX=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('notexpired01','bf1c8084-ba50-4ce7-9439-34653001fc3b','oVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HY=','','"100.64.117.98"','activehost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'activehost','activehost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,1,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
|
||||
@@ -1586,7 +1586,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1609,7 +1609,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -433,7 +433,6 @@ func setSessionCookie(w http.ResponseWriter, token string, expiration time.Durat
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: auth.SessionCookieName,
|
||||
Value: token,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
|
||||
@@ -391,15 +391,6 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
assert.Equal(t, http.SameSiteLaxMode, sessionCookie.SameSite)
|
||||
}
|
||||
|
||||
func TestSetSessionCookieHasRootPath(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
setSessionCookie(w, "test-token", time.Hour)
|
||||
|
||||
cookies := w.Result().Cookies()
|
||||
require.Len(t, cookies, 1)
|
||||
assert.Equal(t, "/", cookies[0].Path, "session cookie must be scoped to root so it applies to all paths")
|
||||
}
|
||||
|
||||
func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
120
shared/settingoverrider/overrider.go
Normal file
120
shared/settingoverrider/overrider.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package settingoverrider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultInterval = 5 * time.Minute
|
||||
)
|
||||
|
||||
// ApplyFunc is called with the raw Redis string value whenever it changes.
|
||||
// The function is responsible for parsing and applying the value.
|
||||
// Return an error to log a warning without stopping the polling loop.
|
||||
type ApplyFunc func(value string) error
|
||||
|
||||
// Overrider holds a shared Redis connection and allows registering
|
||||
// individual settings that are polled independently.
|
||||
type Overrider struct {
|
||||
client *redis.Client
|
||||
cancel context.CancelFunc
|
||||
ctx context.Context
|
||||
noop bool
|
||||
}
|
||||
|
||||
// New creates an Overrider by connecting to Redis at the given address.
|
||||
// The address should follow the Redis URL format (e.g. "redis://localhost:6379").
|
||||
func New(ctx context.Context, redisAddr string) (*Overrider, error) {
|
||||
if redisAddr == "" {
|
||||
return nil, fmt.Errorf("redis address is empty")
|
||||
}
|
||||
|
||||
options, err := redis.ParseURL(redisAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing redis address: %w", err)
|
||||
}
|
||||
|
||||
client := redis.NewClient(options)
|
||||
|
||||
pingCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if _, err := client.Ping(pingCtx).Result(); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, fmt.Errorf("connecting to redis: %w", err)
|
||||
}
|
||||
|
||||
oCtx, oCancel := context.WithCancel(ctx)
|
||||
|
||||
return &Overrider{client: client, cancel: oCancel, ctx: oCtx}, nil
|
||||
}
|
||||
|
||||
// NewNoop returns an Overrider that does nothing.
|
||||
// Poll calls are silently ignored and Close is a no-op.
|
||||
func NewNoop() *Overrider {
|
||||
return &Overrider{noop: true}
|
||||
}
|
||||
|
||||
// Close stops all polling goroutines and closes the underlying Redis client.
|
||||
func (o *Overrider) Close() error {
|
||||
if o.noop {
|
||||
return nil
|
||||
}
|
||||
o.cancel()
|
||||
return o.client.Close()
|
||||
}
|
||||
|
||||
// Poll starts a background goroutine that polls a single Redis key at the given interval
|
||||
// and calls apply whenever the value changes. The goroutine stops when the Overrider is closed.
|
||||
func (o *Overrider) Poll(interval time.Duration, redisKey string, apply ApplyFunc) {
|
||||
if o.noop {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
var lastSeen *string
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-o.ctx.Done():
|
||||
log.WithContext(o.ctx).Infof("Stopping settings overrider for key %q", redisKey)
|
||||
return
|
||||
case <-ticker.C:
|
||||
getCtx, cancel := context.WithTimeout(o.ctx, 5*time.Second)
|
||||
val, err := o.client.Get(getCtx, redisKey).Result()
|
||||
cancel()
|
||||
|
||||
if errors.Is(err, redis.Nil) || val == "" {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
if o.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.WithContext(o.ctx).Errorf("Unable to get setting %q from Redis: %v", redisKey, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if lastSeen != nil && *lastSeen == val {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := apply(val); err != nil {
|
||||
log.WithContext(o.ctx).Warnf("Failed to apply setting %q with value %q: %v", redisKey, val, err)
|
||||
continue
|
||||
}
|
||||
|
||||
lastSeen = &val
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
111
shared/settingoverrider/overrider_test.go
Normal file
111
shared/settingoverrider/overrider_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package settingoverrider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
testcontainersredis "github.com/testcontainers/testcontainers-go/modules/redis"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
)
|
||||
|
||||
func TestPoll_AppliesSettingFromRedis(t *testing.T) {
|
||||
o, client := setupOverrider(t)
|
||||
|
||||
key := "test-setting-key"
|
||||
require.NoError(t, client.Set(context.Background(), key, "hello", 0).Err())
|
||||
|
||||
var applied atomic.Value
|
||||
|
||||
o.Poll(100*time.Millisecond, key, func(value string) error {
|
||||
applied.Store(value)
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
v := applied.Load()
|
||||
return v != nil && v.(string) == "hello"
|
||||
}, 5*time.Second, 50*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestPoll_IndependentSettings(t *testing.T) {
|
||||
o, client := setupOverrider(t)
|
||||
|
||||
require.NoError(t, client.Set(context.Background(), "key-a", "val-a", 0).Err())
|
||||
require.NoError(t, client.Set(context.Background(), "key-b", "val-b", 0).Err())
|
||||
|
||||
var gotA, gotB atomic.Value
|
||||
|
||||
o.Poll(100*time.Millisecond, "key-a", func(v string) error { gotA.Store(v); return nil })
|
||||
o.Poll(100*time.Millisecond, "key-b", func(v string) error { gotB.Store(v); return nil })
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
a, b := gotA.Load(), gotB.Load()
|
||||
return a != nil && a.(string) == "val-a" && b != nil && b.(string) == "val-b"
|
||||
}, 5*time.Second, 50*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestPoll_SkipsDuplicateValues(t *testing.T) {
|
||||
o, client := setupOverrider(t)
|
||||
|
||||
key := "test-dedup"
|
||||
require.NoError(t, client.Set(context.Background(), key, "same", 0).Err())
|
||||
|
||||
var count atomic.Int32
|
||||
|
||||
o.Poll(100*time.Millisecond, key, func(string) error {
|
||||
count.Add(1)
|
||||
return nil
|
||||
})
|
||||
|
||||
// wait for a few ticks
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
assert.Equal(t, int32(1), count.Load(), "Apply should be called only once for unchanged value")
|
||||
}
|
||||
|
||||
func setupOverrider(t *testing.T) (*Overrider, *redis.Client) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
redisContainer, err := testcontainersredis.RunContainer(ctx,
|
||||
testcontainers.WithImage("redis:7"),
|
||||
testcontainers.WithWaitStrategy(
|
||||
wait.ForListeningPort("6379/tcp"),
|
||||
),
|
||||
)
|
||||
require.NoError(t, err, "Failed to create redis test container")
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := redisContainer.Terminate(ctx); err != nil {
|
||||
t.Logf("failed to terminate redis container: %s", err)
|
||||
}
|
||||
})
|
||||
|
||||
redisURL, err := redisContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
o, err := New(ctx, redisURL)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
if err := o.Close(); err != nil {
|
||||
t.Logf("failed to close overrider: %s", err)
|
||||
}
|
||||
})
|
||||
|
||||
// separate client for test setup (setting keys)
|
||||
options, err := redis.ParseURL(redisURL)
|
||||
require.NoError(t, err)
|
||||
client := redis.NewClient(options)
|
||||
t.Cleanup(func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Logf("failed to close redis client: %s", err)
|
||||
}
|
||||
})
|
||||
|
||||
return o, client
|
||||
}
|
||||
@@ -18,7 +18,9 @@ import (
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/shared/metrics"
|
||||
"github.com/netbirdio/netbird/shared/settingoverrider"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
@@ -114,7 +116,24 @@ var (
|
||||
}
|
||||
}()
|
||||
|
||||
srv, err := server.NewServer(cmd.Context(), metricsServer.Meter)
|
||||
overrider := settingoverrider.NewNoop()
|
||||
if redisAddr := cache.GetAddrFromEnv(); redisAddr != "" {
|
||||
overrider, err = settingoverrider.New(cmd.Context(), redisAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create setting overrider: %w", err)
|
||||
}
|
||||
defer func() { _ = overrider.Close() }()
|
||||
}
|
||||
overrider.Poll(settingoverrider.DefaultInterval, "signalLogLevel", func(value string) error {
|
||||
level, err := log.ParseLevel(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing log level %q: %w", value, err)
|
||||
}
|
||||
log.SetLevel(level)
|
||||
return nil
|
||||
})
|
||||
|
||||
srv, err := server.NewServer(cmd.Context(), metricsServer.Meter, overrider)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating signal server: %v", err)
|
||||
}
|
||||
|
||||
135
signal/server/send_tracker.go
Normal file
135
signal/server/send_tracker.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultSendRateLogInterval = 5 * time.Minute
|
||||
defaultSendRateTopPercent = 0.95
|
||||
envSendRateLogInterval = "NB_SIGNAL_SEND_RATE_LOG_INTERVAL"
|
||||
envSendRateTopPercent = "NB_SIGNAL_SEND_RATE_LOG_TOP_PERCENT"
|
||||
)
|
||||
|
||||
// sendRateTracker tracks per-key message counts and logs the busiest peers periodically.
|
||||
type sendRateTracker struct {
|
||||
mu sync.Mutex
|
||||
counts map[string]int64
|
||||
|
||||
// atomic so they can be updated by the setting overrider without locking
|
||||
intervalNs atomic.Int64
|
||||
// topPercent stored as float64 bits for atomic access
|
||||
topPercentBits atomic.Uint64
|
||||
}
|
||||
|
||||
func newSendRateTracker() *sendRateTracker {
|
||||
interval := defaultSendRateLogInterval
|
||||
if v := os.Getenv(envSendRateLogInterval); v != "" {
|
||||
if parsed, err := time.ParseDuration(v); err == nil && parsed > 0 {
|
||||
interval = parsed
|
||||
}
|
||||
}
|
||||
|
||||
topPercent := defaultSendRateTopPercent
|
||||
if v := os.Getenv(envSendRateTopPercent); v != "" {
|
||||
if parsed, err := strconv.ParseFloat(v, 64); err == nil && parsed > 0 && parsed <= 1 {
|
||||
topPercent = parsed
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("send rate tracker: interval=%s, top_percent=%.2f", interval, topPercent)
|
||||
|
||||
t := &sendRateTracker{
|
||||
counts: make(map[string]int64),
|
||||
}
|
||||
t.intervalNs.Store(int64(interval))
|
||||
t.topPercentBits.Store(math.Float64bits(topPercent))
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *sendRateTracker) getInterval() time.Duration {
|
||||
return time.Duration(t.intervalNs.Load())
|
||||
}
|
||||
|
||||
func (t *sendRateTracker) setInterval(d time.Duration) {
|
||||
t.intervalNs.Store(int64(d))
|
||||
}
|
||||
|
||||
func (t *sendRateTracker) getTopPercent() float64 {
|
||||
return math.Float64frombits(t.topPercentBits.Load())
|
||||
}
|
||||
|
||||
func (t *sendRateTracker) setTopPercent(p float64) {
|
||||
t.topPercentBits.Store(math.Float64bits(p))
|
||||
}
|
||||
|
||||
func (t *sendRateTracker) increment(key string) {
|
||||
t.mu.Lock()
|
||||
t.counts[key]++
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// resetAndSnapshot atomically returns current counts and resets the tracker.
|
||||
func (t *sendRateTracker) resetAndSnapshot() map[string]int64 {
|
||||
t.mu.Lock()
|
||||
snap := t.counts
|
||||
t.counts = make(map[string]int64, len(snap))
|
||||
t.mu.Unlock()
|
||||
return snap
|
||||
}
|
||||
|
||||
// logSendRates periodically logs peers in the top percentile of the busiest peer.
|
||||
func (t *sendRateTracker) logSendRates(ctx context.Context) {
|
||||
currentInterval := t.getInterval()
|
||||
ticker := time.NewTicker(currentInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if newInterval := t.getInterval(); newInterval != currentInterval {
|
||||
currentInterval = newInterval
|
||||
ticker.Reset(currentInterval)
|
||||
}
|
||||
|
||||
snap := t.resetAndSnapshot()
|
||||
if len(snap) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var maxCount int64
|
||||
for _, count := range snap {
|
||||
if count > maxCount {
|
||||
maxCount = count
|
||||
}
|
||||
}
|
||||
|
||||
topPercent := t.getTopPercent()
|
||||
threshold := int64(float64(maxCount) * topPercent)
|
||||
intervalMin := currentInterval.Minutes()
|
||||
|
||||
log.Debugf("send rate stats: %d unique peers in last %.0fs, max rate %.1f msg/min",
|
||||
len(snap), currentInterval.Seconds(), float64(maxCount)/intervalMin)
|
||||
logged := 0
|
||||
for key, count := range snap {
|
||||
if count >= threshold {
|
||||
log.Debugf("peer [%s] %.1f msg/min", key, float64(count)/intervalMin)
|
||||
logged++
|
||||
if logged >= 100 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
56
signal/server/send_tracker_test.go
Normal file
56
signal/server/send_tracker_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSendRateTracker_Increment(t *testing.T) {
|
||||
tracker := newSendRateTracker()
|
||||
|
||||
tracker.increment("peer-a")
|
||||
tracker.increment("peer-a")
|
||||
tracker.increment("peer-b")
|
||||
|
||||
snap := tracker.resetAndSnapshot()
|
||||
if snap["peer-a"] != 2 {
|
||||
t.Errorf("expected peer-a count 2, got %d", snap["peer-a"])
|
||||
}
|
||||
if snap["peer-b"] != 1 {
|
||||
t.Errorf("expected peer-b count 1, got %d", snap["peer-b"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendRateTracker_ResetAndSnapshot_Resets(t *testing.T) {
|
||||
tracker := newSendRateTracker()
|
||||
tracker.increment("peer-a")
|
||||
|
||||
snap1 := tracker.resetAndSnapshot()
|
||||
if snap1["peer-a"] != 1 {
|
||||
t.Fatalf("expected 1, got %d", snap1["peer-a"])
|
||||
}
|
||||
|
||||
snap2 := tracker.resetAndSnapshot()
|
||||
if len(snap2) != 0 {
|
||||
t.Errorf("expected empty snapshot after reset, got %v", snap2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendRateTracker_ConcurrentIncrement(t *testing.T) {
|
||||
tracker := newSendRateTracker()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tracker.increment("peer-x")
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
snap := tracker.resetAndSnapshot()
|
||||
if snap["peer-x"] != 100 {
|
||||
t.Errorf("expected 100, got %d", snap["peer-x"])
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -17,6 +18,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/signal-dispatcher/dispatcher"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/settingoverrider"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
"github.com/netbirdio/netbird/signal/metrics"
|
||||
"github.com/netbirdio/netbird/signal/peer"
|
||||
@@ -59,10 +61,12 @@ type Server struct {
|
||||
successHeader metadata.MD
|
||||
|
||||
sendTimeout time.Duration
|
||||
|
||||
sendTracker *sendRateTracker
|
||||
}
|
||||
|
||||
// NewServer creates a new Signal server
|
||||
func NewServer(ctx context.Context, meter metric.Meter, metricsPrefix ...string) (*Server, error) {
|
||||
func NewServer(ctx context.Context, meter metric.Meter, overrider *settingoverrider.Overrider, metricsPrefix ...string) (*Server, error) {
|
||||
appMetrics, err := metrics.NewAppMetrics(meter, metricsPrefix...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating app metrics: %v", err)
|
||||
@@ -80,14 +84,36 @@ func NewServer(ctx context.Context, meter metric.Meter, metricsPrefix ...string)
|
||||
sTimeout = parsed
|
||||
}
|
||||
|
||||
tracker := newSendRateTracker()
|
||||
|
||||
s := &Server{
|
||||
dispatcher: d,
|
||||
registry: peer.NewRegistry(appMetrics),
|
||||
metrics: appMetrics,
|
||||
successHeader: metadata.Pairs(proto.HeaderRegistered, "1"),
|
||||
sendTimeout: sTimeout,
|
||||
sendTracker: tracker,
|
||||
}
|
||||
|
||||
overrider.Poll(settingoverrider.DefaultInterval, "signalSendRateLogInterval", func(value string) error {
|
||||
parsed, err := time.ParseDuration(value)
|
||||
if err != nil || parsed <= 0 {
|
||||
return fmt.Errorf("invalid send rate log interval %q: %w", value, err)
|
||||
}
|
||||
tracker.setInterval(parsed)
|
||||
return nil
|
||||
})
|
||||
overrider.Poll(settingoverrider.DefaultInterval, "signalSendRateTopPercent", func(value string) error {
|
||||
parsed, err := strconv.ParseFloat(value, 64)
|
||||
if err != nil || parsed <= 0 || parsed > 1 {
|
||||
return fmt.Errorf("invalid send rate top percent %q: %w", value, err)
|
||||
}
|
||||
tracker.setTopPercent(parsed)
|
||||
return nil
|
||||
})
|
||||
|
||||
go tracker.logSendRates(ctx)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
@@ -95,6 +121,8 @@ func NewServer(ctx context.Context, meter metric.Meter, metricsPrefix ...string)
|
||||
func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
log.Tracef("received a new message to send from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)
|
||||
|
||||
s.sendTracker.increment(msg.Key)
|
||||
|
||||
if _, found := s.registry.Get(msg.RemoteKey); found {
|
||||
s.forwardMessageToPeer(ctx, msg)
|
||||
return &proto.EncryptedMessage{}, nil
|
||||
|
||||
Reference in New Issue
Block a user