Compare commits

...

7 Commits

Author SHA1 Message Date
Pascal Fischer
ca46fe215a use single latest message buf 2025-10-08 17:23:28 +02:00
Pascal Fischer
e5f926fa6d remove additional network map object in update message 2025-10-08 17:11:25 +02:00
hakansa
229c65ffa1 Enhance showLoginURL to include connection status check and auto-close functionality (#4525) 2025-10-08 12:42:15 +02:00
Zoltan Papp
4d33567888 [client] Remove endpoint address on peer disconnect, retain status for activity recording (#4228)
* When a peer disconnects, remove the endpoint address to avoid sending traffic to a non-existent address, but retain the status for the activity recorder.
2025-10-08 03:12:16 +02:00
Viktor Liu
88467883fc [management,signal] Remove ws-proxy read deadline (#4598) 2025-10-06 22:05:48 +02:00
Viktor Liu
954f40991f [client,management,signal] Handle grpc from ws proxy internally instead of via tcp (#4593) 2025-10-06 21:22:19 +02:00
Maycon Santos
34341d95a9 Adjust signal port for websocket connections (#4594) 2025-10-06 15:22:02 -03:00
19 changed files with 348 additions and 172 deletions

View File

@@ -29,7 +29,8 @@ func Backoff(ctx context.Context) backoff.BackOff {
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled { // for js, the outer websocket layer takes care of tls
if tlsEnabled && runtime.GOOS != "js" {
certPool, err := x509.SystemCertPool() certPool, err := x509.SystemCertPool()
if err != nil || certPool == nil { if err != nil || certPool == nil {
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err) log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
@@ -37,8 +38,6 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
} }
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
// for js, outer websocket layer takes care of tls verification via WithCustomDialer
InsecureSkipVerify: runtime.GOOS == "js",
RootCAs: certPool, RootCAs: certPool,
})) }))
} }

View File

@@ -73,6 +73,44 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
return nil return nil
} }
// RemoveEndpointAddress removes the endpoint address of the given peer while
// keeping the peer itself configured, so traffic is no longer sent to a stale
// address but activity recording for the peer keeps working.
//
// Kernel WireGuard offers no direct way to clear an endpoint, so the peer is
// removed and immediately re-added without one.
func (c *KernelConfigurer) RemoveEndpointAddress(peerKey string) error {
	peerKeyParsed, err := wgtypes.ParseKey(peerKey)
	if err != nil {
		return fmt.Errorf("parse peer key: %w", err)
	}

	// Get the existing peer to preserve its current configuration
	existingPeer, err := c.getPeer(c.deviceName, peerKey)
	if err != nil {
		return fmt.Errorf("get peer: %w", err)
	}

	removePeerCfg := wgtypes.PeerConfig{
		PublicKey: peerKeyParsed,
		Remove:    true,
	}
	if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{removePeerCfg}}); err != nil {
		return fmt.Errorf(`error removing peer %s from interface %s: %w`, peerKey, c.deviceName, err)
	}

	// Re-add the peer without the endpoint but with the same AllowedIPs.
	reAddPeerCfg := wgtypes.PeerConfig{
		PublicKey:         peerKeyParsed,
		AllowedIPs:        existingPeer.AllowedIPs,
		ReplaceAllowedIPs: true,
	}
	// Preserve the preshared key and keepalive interval of the removed peer;
	// dropping them would prevent future handshakes once a new endpoint is
	// learned for the peer.
	if existingPeer.PresharedKey != (wgtypes.Key{}) {
		psk := existingPeer.PresharedKey
		reAddPeerCfg.PresharedKey = &psk
	}
	if existingPeer.PersistentKeepaliveInterval > 0 {
		keepalive := existingPeer.PersistentKeepaliveInterval
		reAddPeerCfg.PersistentKeepalive = &keepalive
	}
	if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{reAddPeerCfg}}); err != nil {
		return fmt.Errorf(
			`error re-adding peer %s to interface %s with allowed IPs %v: %w`,
			peerKey, c.deviceName, existingPeer.AllowedIPs, err,
		)
	}

	return nil
}
func (c *KernelConfigurer) RemovePeer(peerKey string) error { func (c *KernelConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {

View File

@@ -106,6 +106,67 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
return nil return nil
} }
// RemoveEndpointAddress removes the endpoint address of the given peer while
// keeping the peer itself configured, so traffic is no longer sent to a stale
// address but activity recording for the peer keeps working.
//
// The userspace device offers no direct way to clear an endpoint, so the peer
// is removed and immediately re-added with its previous allowed IPs.
// NOTE(review): the preshared key and keepalive interval are not restored by
// the re-add — confirm whether parseStatus exposes them and whether callers
// rely on them surviving this operation.
func (c *WGUSPConfigurer) RemoveEndpointAddress(peerKey string) error {
	peerKeyParsed, err := wgtypes.ParseKey(peerKey)
	if err != nil {
		return fmt.Errorf("parse peer key: %w", err)
	}

	ipcStr, err := c.device.IpcGet()
	if err != nil {
		return fmt.Errorf("get IPC config: %w", err)
	}

	// Parse current status to get allowed IPs for the peer
	stats, err := parseStatus(c.deviceName, ipcStr)
	if err != nil {
		return fmt.Errorf("parse IPC config: %w", err)
	}

	var allowedIPs []net.IPNet
	found := false
	for _, peer := range stats.Peers {
		if peer.PublicKey == peerKey {
			allowedIPs = peer.AllowedIPs
			found = true
			break
		}
	}
	if !found {
		return fmt.Errorf("peer %s not found", peerKey)
	}

	// remove the peer from the WireGuard configuration
	removeCfg := wgtypes.Config{
		Peers: []wgtypes.PeerConfig{{
			PublicKey: peerKeyParsed,
			Remove:    true,
		}},
	}
	if err := c.device.IpcSet(toWgUserspaceString(removeCfg)); err != nil {
		return fmt.Errorf("remove peer: %w", err)
	}

	// re-add the peer without the endpoint but with the same allowed IPs
	reAddCfg := wgtypes.Config{
		Peers: []wgtypes.PeerConfig{{
			PublicKey:         peerKeyParsed,
			ReplaceAllowedIPs: true,
			AllowedIPs:        allowedIPs,
		}},
	}
	if err := c.device.IpcSet(toWgUserspaceString(reAddCfg)); err != nil {
		return fmt.Errorf("re-add peer: %w", err)
	}

	return nil
}
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error { func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {

View File

@@ -21,4 +21,5 @@ type WGConfigurer interface {
GetStats() (map[string]configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error) FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time LastActivities() map[string]monotime.Time
RemoveEndpointAddress(peerKey string) error
} }

View File

@@ -148,6 +148,17 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
} }
// RemoveEndpointAddress clears the endpoint address of the given peer on the
// WireGuard interface while keeping the peer configured. It returns
// ErrIfaceNotFound if the interface has no configurer yet.
func (w *WGIface) RemoveEndpointAddress(peerKey string) error {
	w.mu.Lock()
	defer w.mu.Unlock()

	cfg := w.configurer
	if cfg == nil {
		return ErrIfaceNotFound
	}

	log.Debugf("Removing endpoint address: %s", peerKey)
	return cfg.RemoveEndpointAddress(peerKey)
}
// RemovePeer removes a Wireguard Peer from the interface iface // RemovePeer removes a Wireguard Peer from the interface iface
func (w *WGIface) RemovePeer(peerKey string) error { func (w *WGIface) RemovePeer(peerKey string) error {
w.mu.Lock() w.mu.Lock()

View File

@@ -105,6 +105,10 @@ type MockWGIface struct {
LastActivitiesFunc func() map[string]monotime.Time LastActivitiesFunc func() map[string]monotime.Time
} }
// RemoveEndpointAddress is a no-op mock implementation that always succeeds,
// satisfying the interface for tests that do not care about endpoint removal.
func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
	return nil
}
func (m *MockWGIface) FullStats() (*configurer.Stats, error) { func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
return nil, fmt.Errorf("not implemented") return nil, fmt.Errorf("not implemented")
} }

View File

@@ -28,6 +28,7 @@ type wgIfaceBase interface {
UpdateAddr(newAddr string) error UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemoveEndpointAddress(key string) error
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error

View File

@@ -430,6 +430,9 @@ func (conn *Conn) onICEStateDisconnected() {
} else { } else {
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String()) conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
conn.currentConnPriority = conntype.None conn.currentConnPriority = conntype.None
if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
}
} }
changed := conn.statusICE.Get() != worker.StatusDisconnected changed := conn.statusICE.Get() != worker.StatusDisconnected
@@ -523,6 +526,9 @@ func (conn *Conn) onRelayDisconnected() {
if conn.currentConnPriority == conntype.Relay { if conn.currentConnPriority == conntype.Relay {
conn.Log.Debugf("clean up WireGuard config") conn.Log.Debugf("clean up WireGuard config")
conn.currentConnPriority = conntype.None conn.currentConnPriority = conntype.None
if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
}
} }
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {

View File

@@ -18,4 +18,5 @@ type WGIface interface {
GetStats() (map[string]configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
Address() wgaddr.Address Address() wgaddr.Address
RemoveEndpointAddress(key string) error
} }

View File

@@ -1354,7 +1354,13 @@ func (s *serviceClient) updateConfig() error {
} }
// showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL. // showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL.
func (s *serviceClient) showLoginURL() { // It also starts a background goroutine that periodically checks if the client is already connected
// and closes the window if so. The goroutine can be cancelled by the returned CancelFunc, and it is
// also cancelled when the window is closed.
func (s *serviceClient) showLoginURL() context.CancelFunc {
// create a cancellable context for the background check goroutine
ctx, cancel := context.WithCancel(s.ctx)
resIcon := fyne.NewStaticResource("netbird.png", iconAbout) resIcon := fyne.NewStaticResource("netbird.png", iconAbout)
@@ -1363,6 +1369,8 @@ func (s *serviceClient) showLoginURL() {
s.wLoginURL.Resize(fyne.NewSize(400, 200)) s.wLoginURL.Resize(fyne.NewSize(400, 200))
s.wLoginURL.SetIcon(resIcon) s.wLoginURL.SetIcon(resIcon)
} }
// ensure goroutine is cancelled when the window is closed
s.wLoginURL.SetOnClosed(func() { cancel() })
// add a description label // add a description label
label := widget.NewLabel("Your NetBird session has expired.\nPlease re-authenticate to continue using NetBird.") label := widget.NewLabel("Your NetBird session has expired.\nPlease re-authenticate to continue using NetBird.")
@@ -1443,7 +1451,39 @@ func (s *serviceClient) showLoginURL() {
) )
s.wLoginURL.SetContent(container.NewCenter(content)) s.wLoginURL.SetContent(container.NewCenter(content))
// start a goroutine to check connection status and close the window if connected
go func() {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
conn, err := s.getSrvClient(failFastTimeout)
if err != nil {
return
}
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
if err != nil {
continue
}
if status.Status == string(internal.StatusConnected) {
if s.wLoginURL != nil {
s.wLoginURL.Close()
}
return
}
}
}
}()
s.wLoginURL.Show() s.wLoginURL.Show()
// return cancel func so callers can stop the background goroutine if desired
return cancel
} }
func openURL(url string) error { func openURL(url string) error {

View File

@@ -47,7 +47,7 @@ services:
- traefik.enable=true - traefik.enable=true
- traefik.http.routers.netbird-wsproxy-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/ws-proxy/signal`) - traefik.http.routers.netbird-wsproxy-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/ws-proxy/signal`)
- traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal - traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal
- traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=10000 - traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=80
- traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`) - traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`)
- traefik.http.services.netbird-signal.loadbalancer.server.port=10000 - traefik.http.services.netbird-signal.loadbalancer.server.port=10000
- traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c - traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c

View File

@@ -621,7 +621,7 @@ renderCaddyfile() {
# relay # relay
reverse_proxy /relay* relay:80 reverse_proxy /relay* relay:80
# Signal # Signal
reverse_proxy /ws-proxy/signal* signal:10000 reverse_proxy /ws-proxy/signal* signal:80
reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000 reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000
# Management # Management
reverse_proxy /api/* management:80 reverse_proxy /api/* management:80

View File

@@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"net/netip"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -252,7 +251,7 @@ func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config)
} }
func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler { func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), ManagementLegacyPort), wsproxyserver.WithOTelMeter(meter)) wsProxy := wsproxyserver.New(gRPCHandler, wsproxyserver.WithOTelMeter(meter))
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
switch { switch {

View File

@@ -86,7 +86,7 @@ func NewServer(
if appMetrics != nil { if appMetrics != nil {
// update gauge based on number of connected peers which is equal to open gRPC streams // update gauge based on number of connected peers which is equal to open gRPC streams
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 { err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
return int64(len(peersUpdateManager.peerChannels)) return int64(peersUpdateManager.GetChannelCount())
}) })
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -1270,12 +1270,10 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start)) am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
}(peer) }(peer)
} }
//
wg.Wait() wg.Wait()
if am.metrics != nil { if am.metrics != nil {
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart)) am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart))
@@ -1381,7 +1379,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion) dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion)
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update})
} }
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. // getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
@@ -1603,7 +1601,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
}, },
}, },
}, },
NetworkMap: &types.NetworkMap{},
}) })
am.peersUpdateManager.CloseChannel(ctx, peer.ID) am.peersUpdateManager.CloseChannel(ctx, peer.ID)
peerDeletedEvents = append(peerDeletedEvents, func() { peerDeletedEvents = append(peerDeletedEvents, func() {

View File

@@ -1043,8 +1043,8 @@ func TestUpdateAccountPeers(t *testing.T) {
for _, channel := range peerChannels { for _, channel := range peerChannels {
update := <-channel update := <-channel
assert.Nil(t, update.Update.NetbirdConfig) assert.Nil(t, update.Update.NetbirdConfig)
assert.Equal(t, tc.peers, len(update.NetworkMap.Peers)) // assert.Equal(t, tc.peers, len(update.NetworkMap.Peers))
assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules)) // assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules))
} }
}) })
} }

View File

@@ -7,23 +7,25 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/proto"
) )
const channelBufferSize = 100
type UpdateMessage struct { type UpdateMessage struct {
Update *proto.SyncResponse Update *proto.SyncResponse
NetworkMap *types.NetworkMap }
type peerUpdate struct {
mu sync.Mutex
message *UpdateMessage
notify chan struct{}
} }
type PeersUpdateManager struct { type PeersUpdateManager struct {
// peerChannels is an update channel indexed by Peer.ID // latestUpdates stores the latest update message per peer
peerChannels map[string]chan *UpdateMessage latestUpdates sync.Map // map[string]*peerUpdate
// channelsMux keeps the mutex to access peerChannels // activePeers tracks which peers have active sender goroutines
channelsMux *sync.RWMutex activePeers sync.Map // map[string]struct{}
// metrics provides method to collect application metrics // metrics provides method to collect application metrics
metrics telemetry.AppMetrics metrics telemetry.AppMetrics
} }
@@ -31,87 +33,137 @@ type PeersUpdateManager struct {
// NewPeersUpdateManager returns a new instance of PeersUpdateManager // NewPeersUpdateManager returns a new instance of PeersUpdateManager
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager { func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
return &PeersUpdateManager{ return &PeersUpdateManager{
peerChannels: make(map[string]chan *UpdateMessage),
channelsMux: &sync.RWMutex{},
metrics: metrics, metrics: metrics,
} }
} }
// SendUpdate sends update message to the peer's channel // SendUpdate stores the latest update message for a peer and notifies the sender goroutine
func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *UpdateMessage) { func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *UpdateMessage) {
start := time.Now() start := time.Now()
var found, dropped bool var found, dropped bool
p.channelsMux.RLock()
defer func() { defer func() {
p.channelsMux.RUnlock()
if p.metrics != nil { if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountSendUpdateDuration(time.Since(start), found, dropped) p.metrics.UpdateChannelMetrics().CountSendUpdateDuration(time.Since(start), found, dropped)
} }
}() }()
if channel, ok := p.peerChannels[peerID]; ok { // Check if peer has an active sender goroutine
found = true if _, ok := p.activePeers.Load(peerID); !ok {
select { log.WithContext(ctx).Debugf("peer %s has no active sender", peerID)
case channel <- update: return
log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID)
default:
dropped = true
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))
} }
} else {
log.WithContext(ctx).Debugf("peer %s has no channel", peerID) found = true
// Load or create peerUpdate entry
val, _ := p.latestUpdates.LoadOrStore(peerID, &peerUpdate{
notify: make(chan struct{}, 1),
})
pu := val.(*peerUpdate)
// Store the latest message (overwrites any previous unsent message)
pu.mu.Lock()
pu.message = update
pu.mu.Unlock()
// Non-blocking notification
select {
case pu.notify <- struct{}{}:
log.WithContext(ctx).Debugf("update notification sent for peer %s", peerID)
default:
// Already notified, sender will pick up the latest message anyway
log.WithContext(ctx).Tracef("peer %s already notified, update will be picked up", peerID)
} }
} }
// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer. // CreateChannel creates a sender goroutine for a given peer and returns a channel to receive updates
func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage { func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage {
start := time.Now() start := time.Now()
closed := false closed := false
p.channelsMux.Lock()
defer func() { defer func() {
p.channelsMux.Unlock()
if p.metrics != nil { if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountCreateChannelDuration(time.Since(start), closed) p.metrics.UpdateChannelMetrics().CountCreateChannelDuration(time.Since(start), closed)
} }
}() }()
if channel, ok := p.peerChannels[peerID]; ok { // Close existing sender if any
if _, exists := p.activePeers.LoadOrStore(peerID, struct{}{}); exists {
closed = true closed = true
delete(p.peerChannels, peerID) p.closeChannel(ctx, peerID)
close(channel)
}
// mbragin: todo shouldn't it be more? or configurable?
channel := make(chan *UpdateMessage, channelBufferSize)
p.peerChannels[peerID] = channel
log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID)
return channel
} }
func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) { // Create peerUpdate entry with notification channel
if channel, ok := p.peerChannels[peerID]; ok { pu := &peerUpdate{
delete(p.peerChannels, peerID) notify: make(chan struct{}, 1),
close(channel) }
p.latestUpdates.Store(peerID, pu)
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID) // Create output channel for consumer
outChan := make(chan *UpdateMessage, 1)
// Start sender goroutine
go func() {
defer close(outChan)
for {
select {
case <-ctx.Done():
log.WithContext(ctx).Debugf("sender goroutine for peer %s stopped due to context cancellation", peerID)
return
case <-pu.notify:
// Check if still active
if _, ok := p.activePeers.Load(peerID); !ok {
log.WithContext(ctx).Debugf("sender goroutine for peer %s stopped", peerID)
return return
} }
log.WithContext(ctx).Debugf("closing updates channel: peer %s has no channel", peerID) // Get the latest message with mutex protection
pu.mu.Lock()
msg := pu.message
pu.message = nil // Clear after reading
pu.mu.Unlock()
if msg != nil {
select {
case outChan <- msg:
log.WithContext(ctx).Tracef("sent update to peer %s", peerID)
case <-ctx.Done():
return
}
}
}
}
}()
log.WithContext(ctx).Debugf("created sender goroutine for peer %s", peerID)
return outChan
} }
// CloseChannels closes updates channel for each given peer func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) {
// Mark peer as inactive to stop the sender goroutine
if _, ok := p.activePeers.LoadAndDelete(peerID); ok {
// Close notification channel
if val, ok := p.latestUpdates.Load(peerID); ok {
pu := val.(*peerUpdate)
close(pu.notify)
}
p.latestUpdates.Delete(peerID)
log.WithContext(ctx).Debugf("closed sender for peer %s", peerID)
return
}
log.WithContext(ctx).Debugf("closing sender: peer %s has no active sender", peerID)
}
// CloseChannels closes sender goroutines for each given peer
func (p *PeersUpdateManager) CloseChannels(ctx context.Context, peerIDs []string) { func (p *PeersUpdateManager) CloseChannels(ctx context.Context, peerIDs []string) {
start := time.Now() start := time.Now()
p.channelsMux.Lock()
defer func() { defer func() {
p.channelsMux.Unlock()
if p.metrics != nil { if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountCloseChannelsDuration(time.Since(start), len(peerIDs)) p.metrics.UpdateChannelMetrics().CountCloseChannelsDuration(time.Since(start), len(peerIDs))
} }
@@ -122,13 +174,11 @@ func (p *PeersUpdateManager) CloseChannels(ctx context.Context, peerIDs []string
} }
} }
// CloseChannel closes updates channel of a given peer // CloseChannel closes the sender goroutine of a given peer
func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string) { func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string) {
start := time.Now() start := time.Now()
p.channelsMux.Lock()
defer func() { defer func() {
p.channelsMux.Unlock()
if p.metrics != nil { if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountCloseChannelDuration(time.Since(start)) p.metrics.UpdateChannelMetrics().CountCloseChannelDuration(time.Since(start))
} }
@@ -141,38 +191,43 @@ func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string) {
func (p *PeersUpdateManager) GetAllConnectedPeers() map[string]struct{} { func (p *PeersUpdateManager) GetAllConnectedPeers() map[string]struct{} {
start := time.Now() start := time.Now()
p.channelsMux.RLock()
m := make(map[string]struct{}) m := make(map[string]struct{})
defer func() { defer func() {
p.channelsMux.RUnlock()
if p.metrics != nil { if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountGetAllConnectedPeersDuration(time.Since(start), len(m)) p.metrics.UpdateChannelMetrics().CountGetAllConnectedPeersDuration(time.Since(start), len(m))
} }
}() }()
for ID := range p.peerChannels { p.activePeers.Range(func(key, value interface{}) bool {
m[ID] = struct{}{} m[key.(string)] = struct{}{}
} return true
})
return m return m
} }
// HasChannel returns true if peers has channel in update manager, otherwise false // HasChannel returns true if peer has an active sender goroutine, otherwise false
func (p *PeersUpdateManager) HasChannel(peerID string) bool { func (p *PeersUpdateManager) HasChannel(peerID string) bool {
start := time.Now() start := time.Now()
p.channelsMux.RLock()
defer func() { defer func() {
p.channelsMux.RUnlock()
if p.metrics != nil { if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountHasChannelDuration(time.Since(start)) p.metrics.UpdateChannelMetrics().CountHasChannelDuration(time.Since(start))
} }
}() }()
_, ok := p.peerChannels[peerID] _, ok := p.activePeers.Load(peerID)
return ok return ok
} }
// GetChannelCount returns the number of active peer channels
func (p *PeersUpdateManager) GetChannelCount() int {
	var total int
	p.activePeers.Range(func(_, _ interface{}) bool {
		total++
		return true
	})
	return total
}

View File

@@ -10,7 +10,6 @@ import (
"net/http" "net/http"
// nolint:gosec // nolint:gosec
_ "net/http/pprof" _ "net/http/pprof"
"net/netip"
"time" "time"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
@@ -63,10 +62,10 @@ var (
Use: "run", Use: "run",
Short: "start NetBird Signal Server daemon", Short: "start NetBird Signal Server daemon",
SilenceUsage: true, SilenceUsage: true,
PreRun: func(cmd *cobra.Command, args []string) { PreRunE: func(cmd *cobra.Command, args []string) error {
err := util.InitLog(logLevel, logFile) err := util.InitLog(logLevel, logFile)
if err != nil { if err != nil {
log.Fatalf("failed initializing log %v", err) return fmt.Errorf("failed initializing log: %w", err)
} }
flag.Parse() flag.Parse()
@@ -87,6 +86,8 @@ var (
signalPort = 80 signalPort = 80
} }
} }
return nil
}, },
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
flag.Parse() flag.Parse()
@@ -254,7 +255,7 @@ func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler h
} }
func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler { func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler {
wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), legacyGRPCPort), wsproxyserver.WithOTelMeter(meter)) wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch { switch {

View File

@@ -2,42 +2,41 @@ package server
import ( import (
"context" "context"
"errors"
"io" "io"
"net" "net"
"net/http" "net/http"
"net/netip"
"sync" "sync"
"time" "time"
"github.com/coder/websocket" "github.com/coder/websocket"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
"github.com/netbirdio/netbird/util/wsproxy" "github.com/netbirdio/netbird/util/wsproxy"
) )
const ( const (
dialTimeout = 10 * time.Second
bufferSize = 32 * 1024 bufferSize = 32 * 1024
ioTimeout = 5 * time.Second
) )
// Config contains the configuration for the WebSocket proxy. // Config contains the configuration for the WebSocket proxy.
type Config struct { type Config struct {
LocalGRPCAddr netip.AddrPort Handler http.Handler
Path string Path string
MetricsRecorder MetricsRecorder MetricsRecorder MetricsRecorder
} }
// Proxy handles WebSocket to TCP proxying for gRPC connections. // Proxy handles WebSocket to gRPC handler proxying.
type Proxy struct { type Proxy struct {
config Config config Config
metrics MetricsRecorder metrics MetricsRecorder
} }
// New creates a new WebSocket proxy instance with optional configuration // New creates a new WebSocket proxy instance with optional configuration
func New(localGRPCAddr netip.AddrPort, opts ...Option) *Proxy { func New(handler http.Handler, opts ...Option) *Proxy {
config := Config{ config := Config{
LocalGRPCAddr: localGRPCAddr, Handler: handler,
Path: wsproxy.ProxyPath, Path: wsproxy.ProxyPath,
MetricsRecorder: NoOpMetricsRecorder{}, // Default to no-op MetricsRecorder: NoOpMetricsRecorder{}, // Default to no-op
} }
@@ -63,7 +62,7 @@ func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) {
p.metrics.RecordConnection(ctx) p.metrics.RecordConnection(ctx)
defer p.metrics.RecordDisconnection(ctx) defer p.metrics.RecordDisconnection(ctx)
log.Debugf("WebSocket proxy handling connection from %s, forwarding to %s", r.RemoteAddr, p.config.LocalGRPCAddr) log.Debugf("WebSocket proxy handling connection from %s, forwarding to internal gRPC handler", r.RemoteAddr)
acceptOptions := &websocket.AcceptOptions{ acceptOptions := &websocket.AcceptOptions{
OriginPatterns: []string{"*"}, OriginPatterns: []string{"*"},
} }
@@ -75,71 +74,41 @@ func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) {
return return
} }
defer func() { defer func() {
if err := wsConn.Close(websocket.StatusNormalClosure, ""); err != nil { _ = wsConn.Close(websocket.StatusNormalClosure, "")
log.Debugf("Failed to close WebSocket: %v", err)
}
}() }()
log.Debugf("WebSocket proxy attempting to connect to local gRPC at %s", p.config.LocalGRPCAddr) clientConn, serverConn := net.Pipe()
tcpConn, err := net.DialTimeout("tcp", p.config.LocalGRPCAddr.String(), dialTimeout)
if err != nil {
p.metrics.RecordError(ctx, "tcp_dial_failed")
log.Warnf("Failed to connect to local gRPC server at %s: %v", p.config.LocalGRPCAddr, err)
if err := wsConn.Close(websocket.StatusInternalError, "Backend unavailable"); err != nil {
log.Debugf("Failed to close WebSocket after connection failure: %v", err)
}
return
}
defer func() { defer func() {
if err := tcpConn.Close(); err != nil { _ = clientConn.Close()
log.Debugf("Failed to close TCP connection: %v", err) _ = serverConn.Close()
}
}() }()
log.Debugf("WebSocket proxy established: client %s -> local gRPC %s", r.RemoteAddr, p.config.LocalGRPCAddr) log.Debugf("WebSocket proxy established: %s -> gRPC handler", r.RemoteAddr)
p.proxyData(ctx, wsConn, tcpConn) go func() {
(&http2.Server{}).ServeConn(serverConn, &http2.ServeConnOpts{
Context: ctx,
Handler: p.config.Handler,
})
}()
p.proxyData(ctx, wsConn, clientConn, r.RemoteAddr)
} }
func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, tcpConn net.Conn) { func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) {
proxyCtx, cancel := context.WithCancel(ctx) proxyCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
go p.wsToTCP(proxyCtx, cancel, &wg, wsConn, tcpConn) go p.wsToPipe(proxyCtx, cancel, &wg, wsConn, pipeConn, clientAddr)
go p.tcpToWS(proxyCtx, cancel, &wg, wsConn, tcpConn) go p.pipeToWS(proxyCtx, cancel, &wg, wsConn, pipeConn, clientAddr)
done := make(chan struct{})
go func() {
wg.Wait() wg.Wait()
close(done)
}()
select {
case <-done:
log.Tracef("Proxy data transfer completed, both goroutines terminated")
case <-proxyCtx.Done():
log.Tracef("Proxy data transfer cancelled, forcing connection closure")
if err := wsConn.Close(websocket.StatusGoingAway, "proxy cancelled"); err != nil {
log.Tracef("Error closing WebSocket during cancellation: %v", err)
}
if err := tcpConn.Close(); err != nil {
log.Tracef("Error closing TCP connection during cancellation: %v", err)
} }
select { func (p *Proxy) wsToPipe(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) {
case <-done:
log.Tracef("Goroutines terminated after forced connection closure")
case <-time.After(2 * time.Second):
log.Tracef("Goroutines did not terminate within timeout after connection closure")
}
}
}
func (p *Proxy) wsToTCP(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) {
defer wg.Done() defer wg.Done()
defer cancel() defer cancel()
@@ -148,80 +117,73 @@ func (p *Proxy) wsToTCP(ctx context.Context, cancel context.CancelFunc, wg *sync
if err != nil { if err != nil {
switch { switch {
case ctx.Err() != nil: case ctx.Err() != nil:
log.Debugf("wsToTCP goroutine terminating due to context cancellation") log.Debugf("WebSocket from %s terminating due to context cancellation", clientAddr)
case websocket.CloseStatus(err) == websocket.StatusNormalClosure: case websocket.CloseStatus(err) != -1:
log.Debugf("WebSocket closed normally") log.Debugf("WebSocket from %s disconnected", clientAddr)
default: default:
p.metrics.RecordError(ctx, "websocket_read_error") p.metrics.RecordError(ctx, "websocket_read_error")
log.Errorf("WebSocket read error: %v", err) log.Debugf("WebSocket read error from %s: %v", clientAddr, err)
} }
return return
} }
if msgType != websocket.MessageBinary { if msgType != websocket.MessageBinary {
log.Warnf("Unexpected WebSocket message type: %v", msgType) log.Warnf("Unexpected WebSocket message type from %s: %v", clientAddr, msgType)
continue continue
} }
if ctx.Err() != nil { if ctx.Err() != nil {
log.Tracef("wsToTCP goroutine terminating due to context cancellation before TCP write") log.Tracef("wsToPipe goroutine terminating due to context cancellation before pipe write")
return return
} }
if err := tcpConn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil { if err := pipeConn.SetWriteDeadline(time.Now().Add(ioTimeout)); err != nil {
log.Debugf("Failed to set TCP write deadline: %v", err) log.Debugf("Failed to set pipe write deadline: %v", err)
} }
n, err := tcpConn.Write(data) n, err := pipeConn.Write(data)
if err != nil { if err != nil {
p.metrics.RecordError(ctx, "tcp_write_error") p.metrics.RecordError(ctx, "pipe_write_error")
log.Errorf("TCP write error: %v", err) log.Warnf("Pipe write error for %s: %v", clientAddr, err)
return return
} }
p.metrics.RecordBytesTransferred(ctx, "ws_to_tcp", int64(n)) p.metrics.RecordBytesTransferred(ctx, "ws_to_grpc", int64(n))
} }
} }
func (p *Proxy) tcpToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) { func (p *Proxy) pipeToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) {
defer wg.Done() defer wg.Done()
defer cancel() defer cancel()
buf := make([]byte, bufferSize) buf := make([]byte, bufferSize)
for { for {
if err := tcpConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { n, err := pipeConn.Read(buf)
log.Debugf("Failed to set TCP read deadline: %v", err)
}
n, err := tcpConn.Read(buf)
if err != nil { if err != nil {
if ctx.Err() != nil { if ctx.Err() != nil {
log.Tracef("tcpToWS goroutine terminating due to context cancellation") log.Tracef("pipeToWS goroutine terminating due to context cancellation")
return return
} }
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
continue
}
if err != io.EOF { if err != io.EOF {
log.Errorf("TCP read error: %v", err) log.Debugf("Pipe read error for %s: %v", clientAddr, err)
} }
return return
} }
if ctx.Err() != nil { if ctx.Err() != nil {
log.Tracef("tcpToWS goroutine terminating due to context cancellation before WebSocket write") log.Tracef("pipeToWS goroutine terminating due to context cancellation before WebSocket write")
return return
} }
if n > 0 {
if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil { if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil {
p.metrics.RecordError(ctx, "websocket_write_error") p.metrics.RecordError(ctx, "websocket_write_error")
log.Errorf("WebSocket write error: %v", err) log.Warnf("WebSocket write error for %s: %v", clientAddr, err)
return return
} }
p.metrics.RecordBytesTransferred(ctx, "tcp_to_ws", int64(n)) p.metrics.RecordBytesTransferred(ctx, "grpc_to_ws", int64(n))
}
} }
} }