Compare commits

...

6 Commits

Author SHA1 Message Date
Viktor Liu
a6a8d139c7 Fix tests 2025-10-09 14:31:01 +02:00
Viktor Liu
c987cddc85 Fix rule squashing with unsymmetrical proto ALL rules 2025-10-09 14:18:36 +02:00
Zoltan Papp
4d33567888 [client] Remove endpoint address on peer disconnect, retain status for activity recording (#4228)
* When a peer disconnects, remove the endpoint address to avoid sending traffic to a non-existent address, but retain the status for the activity recorder.
2025-10-08 03:12:16 +02:00
Viktor Liu
88467883fc [management,signal] Remove ws-proxy read deadline (#4598) 2025-10-06 22:05:48 +02:00
Viktor Liu
954f40991f [client,management,signal] Handle grpc from ws proxy internally instead of via tcp (#4593) 2025-10-06 21:22:19 +02:00
Maycon Santos
34341d95a9 Adjust signal port for websocket connections (#4594) 2025-10-06 15:22:02 -03:00
16 changed files with 336 additions and 112 deletions

View File

@@ -29,7 +29,8 @@ func Backoff(ctx context.Context) backoff.BackOff {
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled { // for js, the outer websocket layer takes care of tls
if tlsEnabled && runtime.GOOS != "js" {
certPool, err := x509.SystemCertPool() certPool, err := x509.SystemCertPool()
if err != nil || certPool == nil { if err != nil || certPool == nil {
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err) log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
@@ -37,8 +38,6 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
} }
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
// for js, outer websocket layer takes care of tls verification via WithCustomDialer
InsecureSkipVerify: runtime.GOOS == "js",
RootCAs: certPool, RootCAs: certPool,
})) }))
} }

View File

@@ -73,6 +73,44 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
return nil return nil
} }
// RemoveEndpointAddress drops the endpoint of the given peer while keeping its
// allowed IPs.
//
// It looks up the peer's current allowed IPs, removes the peer from the device
// and then re-adds it with the same public key and allowed IPs but without an
// endpoint. The removal and the re-add are two separate configure calls, so if
// the second one fails the peer stays removed and the error is returned.
//
// NOTE(review): only AllowedIPs are carried over to the re-added peer; confirm
// that no other per-peer setting has to be restored here.
func (c *KernelConfigurer) RemoveEndpointAddress(peerKey string) error {
	peerKeyParsed, err := wgtypes.ParseKey(peerKey)
	if err != nil {
		return fmt.Errorf("parse peer key: %w", err)
	}

	// Get the existing peer to preserve its allowed IPs
	existingPeer, err := c.getPeer(c.deviceName, peerKey)
	if err != nil {
		return fmt.Errorf("get peer: %w", err)
	}

	removePeerCfg := wgtypes.PeerConfig{
		PublicKey: peerKeyParsed,
		Remove:    true,
	}
	if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{removePeerCfg}}); err != nil {
		return fmt.Errorf(`error removing peer %s from interface %s: %w`, peerKey, c.deviceName, err)
	}

	// Re-add the peer without the endpoint but same AllowedIPs
	reAddPeerCfg := wgtypes.PeerConfig{
		PublicKey:         peerKeyParsed,
		AllowedIPs:        existingPeer.AllowedIPs,
		ReplaceAllowedIPs: true,
	}
	if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{reAddPeerCfg}}); err != nil {
		return fmt.Errorf(
			`error re-adding peer %s to interface %s with allowed IPs %v: %w`,
			peerKey, c.deviceName, existingPeer.AllowedIPs, err,
		)
	}

	return nil
}
func (c *KernelConfigurer) RemovePeer(peerKey string) error { func (c *KernelConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {

View File

@@ -106,6 +106,67 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
return nil return nil
} }
// RemoveEndpointAddress drops the endpoint of the given peer on the userspace
// device while keeping its allowed IPs.
//
// It reads the current device state over IPC to find the peer's allowed IPs,
// removes the peer and then re-adds it with the same public key and allowed
// IPs but without an endpoint. It returns an error if the key cannot be
// parsed, the device state cannot be read or parsed, the peer is not present,
// or either IPC write fails. The removal and the re-add are separate IPC
// calls, so if the second one fails the peer stays removed.
func (c *WGUSPConfigurer) RemoveEndpointAddress(peerKey string) error {
	peerKeyParsed, err := wgtypes.ParseKey(peerKey)
	if err != nil {
		return fmt.Errorf("parse peer key: %w", err)
	}

	ipcStr, err := c.device.IpcGet()
	if err != nil {
		return fmt.Errorf("get IPC config: %w", err)
	}

	// Parse current status to get allowed IPs for the peer
	stats, err := parseStatus(c.deviceName, ipcStr)
	if err != nil {
		return fmt.Errorf("parse IPC config: %w", err)
	}

	var allowedIPs []net.IPNet
	found := false
	for _, peer := range stats.Peers {
		if peer.PublicKey == peerKey {
			allowedIPs = peer.AllowedIPs
			found = true
			break
		}
	}
	if !found {
		return fmt.Errorf("peer %s not found", peerKey)
	}

	// Remove the peer from the WireGuard configuration.
	removeCfg := wgtypes.Config{
		Peers: []wgtypes.PeerConfig{{
			PublicKey: peerKeyParsed,
			Remove:    true,
		}},
	}
	if err := c.device.IpcSet(toWgUserspaceString(removeCfg)); err != nil {
		// %w (not %s) keeps the underlying error available to errors.Is / errors.As.
		return fmt.Errorf("failed to remove peer: %w", err)
	}

	// Re-add the peer with the previously read allowed IPs and no endpoint.
	reAddCfg := wgtypes.Config{
		Peers: []wgtypes.PeerConfig{{
			PublicKey:         peerKeyParsed,
			ReplaceAllowedIPs: true,
			AllowedIPs:        allowedIPs,
		}},
	}
	if err := c.device.IpcSet(toWgUserspaceString(reAddCfg)); err != nil {
		return fmt.Errorf("remove endpoint address: %w", err)
	}

	return nil
}
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error { func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {

View File

@@ -21,4 +21,5 @@ type WGConfigurer interface {
GetStats() (map[string]configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error) FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time LastActivities() map[string]monotime.Time
RemoveEndpointAddress(peerKey string) error
} }

View File

@@ -148,6 +148,17 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
} }
// RemoveEndpointAddress asks the configurer to remove the endpoint address of
// the peer identified by peerKey. It holds the interface mutex for the whole
// call and returns ErrIfaceNotFound when no configurer is set.
func (w *WGIface) RemoveEndpointAddress(peerKey string) error {
	w.mu.Lock()
	defer w.mu.Unlock()

	cfg := w.configurer
	if cfg == nil {
		return ErrIfaceNotFound
	}

	log.Debugf("Removing endpoint address: %s", peerKey)

	return cfg.RemoveEndpointAddress(peerKey)
}
// RemovePeer removes a Wireguard Peer from the interface iface // RemovePeer removes a Wireguard Peer from the interface iface
func (w *WGIface) RemovePeer(peerKey string) error { func (w *WGIface) RemovePeer(peerKey string) error {
w.mu.Lock() w.mu.Lock()

View File

@@ -388,7 +388,8 @@ func (d *DefaultManager) squashAcceptRules(
// trace which type of protocols was squashed // trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{} squashedRules := []*mgmProto.FirewallRule{}
squashedProtocols := map[mgmProto.RuleProtocol]struct{}{} squashedProtocolsIn := map[mgmProto.RuleProtocol]struct{}{}
squashedProtocolsOut := map[mgmProto.RuleProtocol]struct{}{}
// This function calculates whether the rules can be squashed by protocol or not. // This function calculates whether the rules can be squashed by protocol or not.
// We sum the number of peer IPs found for the given protocol in the original rules list. // We sum the number of peer IPs found for the given protocol in the original rules list.
@@ -397,7 +398,7 @@ func (d *DefaultManager) squashAcceptRules(
// 2. Any of rule contains Port. // 2. Any of rule contains Port.
// //
// We zeroed this to notify squash function that this protocol can't be squashed. // We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) { addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch, squashedProtocols map[mgmProto.RuleProtocol]struct{}) {
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP || hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
r.Port != "" || !portInfoEmpty(r.PortInfo) r.Port != "" || !portInfoEmpty(r.PortInfo)
@@ -435,9 +436,9 @@ func (d *DefaultManager) squashAcceptRules(
for i, r := range networkMap.FirewallRules { for i, r := range networkMap.FirewallRules {
// calculate squash for different directions // calculate squash for different directions
if r.Direction == mgmProto.RuleDirection_IN { if r.Direction == mgmProto.RuleDirection_IN {
addRuleToCalculationMap(i, r, in) addRuleToCalculationMap(i, r, in, squashedProtocolsIn)
} else { } else {
addRuleToCalculationMap(i, r, out) addRuleToCalculationMap(i, r, out, squashedProtocolsOut)
} }
} }
@@ -450,7 +451,7 @@ func (d *DefaultManager) squashAcceptRules(
mgmProto.RuleProtocol_UDP, mgmProto.RuleProtocol_UDP,
} }
squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) { squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection, squashedProtocols map[mgmProto.RuleProtocol]struct{}) {
for _, protocol := range protocolOrders { for _, protocol := range protocolOrders {
match, ok := matches[protocol] match, ok := matches[protocol]
if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 { if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 {
@@ -478,11 +479,22 @@ func (d *DefaultManager) squashAcceptRules(
} }
} }
squash(in, mgmProto.RuleDirection_IN) squash(in, mgmProto.RuleDirection_IN, squashedProtocolsIn)
squash(out, mgmProto.RuleDirection_OUT) squash(out, mgmProto.RuleDirection_OUT, squashedProtocolsOut)
// if the ALL protocol was squashed, everything is allowed and we can ignore all other rules // if the ALL protocol was squashed, everything is allowed and we can ignore all other rules
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { _, inSquashed := squashedProtocolsIn[mgmProto.RuleProtocol_ALL]
_, outSquashed := squashedProtocolsOut[mgmProto.RuleProtocol_ALL]
squashedProtocols := make(map[mgmProto.RuleProtocol]struct{})
for k := range squashedProtocolsIn {
squashedProtocols[k] = struct{}{}
}
for k := range squashedProtocolsOut {
squashedProtocols[k] = struct{}{}
}
if inSquashed && outSquashed {
return squashedRules, squashedProtocols return squashedRules, squashedProtocols
} }
@@ -494,10 +506,15 @@ func (d *DefaultManager) squashAcceptRules(
// filter out the rules that were squashed from the final list // filter out the rules that were squashed from the final list
// if we also have other rules that were not squashed. // if we also have other rules that were not squashed.
for i, r := range networkMap.FirewallRules { for i, r := range networkMap.FirewallRules {
squashedProtocols := squashedProtocolsIn
protocols := in
if r.Direction == mgmProto.RuleDirection_OUT {
squashedProtocols = squashedProtocolsOut
protocols = out
}
if _, ok := squashedProtocols[r.Protocol]; ok { if _, ok := squashedProtocols[r.Protocol]; ok {
if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i { if m, ok := protocols[r.Protocol]; ok && m.ips[r.PeerIP] == i {
continue
} else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i {
continue continue
} }
} }

View File

@@ -758,6 +758,129 @@ func TestPortInfoEmpty(t *testing.T) {
} }
} }
// TestDefaultManagerSquashRulesMixed feeds squashAcceptRules one OUT accept-all
// rule for "0.0.0.0" and one IN accept-all rule for a specific peer IP, and
// expects both rules to come back unchanged: one per direction, each with its
// original PeerIP.
func TestDefaultManagerSquashRulesMixed(t *testing.T) {
	networkMap := &mgmProto.NetworkMap{
		RemotePeers: []*mgmProto.RemotePeerConfig{
			{AllowedIps: []string{"100.66.152.160"}},
		},
		FirewallRules: []*mgmProto.FirewallRule{
			{
				PeerIP:    "0.0.0.0",
				Direction: mgmProto.RuleDirection_OUT,
				Action:    mgmProto.RuleAction_ACCEPT,
				Protocol:  mgmProto.RuleProtocol_ALL,
			},
			{
				PeerIP:    "100.66.152.160",
				Direction: mgmProto.RuleDirection_IN,
				Action:    mgmProto.RuleAction_ACCEPT,
				Protocol:  mgmProto.RuleProtocol_ALL,
			},
		},
	}

	manager := &DefaultManager{}
	rules, _ := manager.squashAcceptRules(networkMap)
	assert.Len(t, rules, 2)

	var inRules, outRules []*mgmProto.FirewallRule
	for _, r := range rules {
		if r.Direction == mgmProto.RuleDirection_IN {
			inRules = append(inRules, r)
		} else {
			outRules = append(outRules, r)
		}
	}

	// assert does not stop the test on failure, so check both lengths and bail
	// out explicitly before indexing to avoid an index-out-of-range panic.
	inOK := assert.Len(t, inRules, 1)
	outOK := assert.Len(t, outRules, 1)
	if !inOK || !outOK {
		return
	}

	assert.Equal(t, "0.0.0.0", outRules[0].PeerIP)
	assert.Equal(t, "100.66.152.160", inRules[0].PeerIP)
}
// TestDefaultManagerSquashRulesBothOptimized feeds squashAcceptRules an OUT and
// an IN accept-all rule that both use "0.0.0.0" as PeerIP, and expects both
// rules to come back: one per direction, each still with PeerIP "0.0.0.0".
func TestDefaultManagerSquashRulesBothOptimized(t *testing.T) {
	networkMap := &mgmProto.NetworkMap{
		RemotePeers: []*mgmProto.RemotePeerConfig{
			{AllowedIps: []string{"100.66.152.160"}},
		},
		FirewallRules: []*mgmProto.FirewallRule{
			{
				PeerIP:    "0.0.0.0",
				Direction: mgmProto.RuleDirection_OUT,
				Action:    mgmProto.RuleAction_ACCEPT,
				Protocol:  mgmProto.RuleProtocol_ALL,
			},
			{
				PeerIP:    "0.0.0.0",
				Direction: mgmProto.RuleDirection_IN,
				Action:    mgmProto.RuleAction_ACCEPT,
				Protocol:  mgmProto.RuleProtocol_ALL,
			},
		},
	}

	manager := &DefaultManager{}
	rules, _ := manager.squashAcceptRules(networkMap)
	assert.Len(t, rules, 2)

	var inRules, outRules []*mgmProto.FirewallRule
	for _, r := range rules {
		if r.Direction == mgmProto.RuleDirection_IN {
			inRules = append(inRules, r)
		} else {
			outRules = append(outRules, r)
		}
	}

	// assert does not stop the test on failure, so check both lengths and bail
	// out explicitly before indexing to avoid an index-out-of-range panic.
	inOK := assert.Len(t, inRules, 1)
	outOK := assert.Len(t, outRules, 1)
	if !inOK || !outOK {
		return
	}

	assert.Equal(t, "0.0.0.0", outRules[0].PeerIP)
	assert.Equal(t, "0.0.0.0", inRules[0].PeerIP)
}
// TestDefaultManagerSquashRulesBothSpecific feeds squashAcceptRules an OUT and
// an IN accept-all rule that both target the same specific peer IP, and expects
// both rules to come back: one per direction, each with that peer IP.
func TestDefaultManagerSquashRulesBothSpecific(t *testing.T) {
	networkMap := &mgmProto.NetworkMap{
		RemotePeers: []*mgmProto.RemotePeerConfig{
			{AllowedIps: []string{"100.66.152.160"}},
		},
		FirewallRules: []*mgmProto.FirewallRule{
			{
				PeerIP:    "100.66.152.160",
				Direction: mgmProto.RuleDirection_OUT,
				Action:    mgmProto.RuleAction_ACCEPT,
				Protocol:  mgmProto.RuleProtocol_ALL,
			},
			{
				PeerIP:    "100.66.152.160",
				Direction: mgmProto.RuleDirection_IN,
				Action:    mgmProto.RuleAction_ACCEPT,
				Protocol:  mgmProto.RuleProtocol_ALL,
			},
		},
	}

	manager := &DefaultManager{}
	rules, _ := manager.squashAcceptRules(networkMap)
	assert.Len(t, rules, 2)

	var inRules, outRules []*mgmProto.FirewallRule
	for _, r := range rules {
		if r.Direction == mgmProto.RuleDirection_IN {
			inRules = append(inRules, r)
		} else {
			outRules = append(outRules, r)
		}
	}

	// assert does not stop the test on failure, so check both lengths and bail
	// out explicitly before indexing to avoid an index-out-of-range panic.
	inOK := assert.Len(t, inRules, 1)
	outOK := assert.Len(t, outRules, 1)
	if !inOK || !outOK {
		return
	}

	assert.Equal(t, "100.66.152.160", outRules[0].PeerIP)
	assert.Equal(t, "100.66.152.160", inRules[0].PeerIP)
}
func TestDefaultManagerEnableSSHRules(t *testing.T) { func TestDefaultManagerEnableSSHRules(t *testing.T) {
networkMap := &mgmProto.NetworkMap{ networkMap := &mgmProto.NetworkMap{
PeerConfig: &mgmProto.PeerConfig{ PeerConfig: &mgmProto.PeerConfig{

View File

@@ -105,6 +105,10 @@ type MockWGIface struct {
LastActivitiesFunc func() map[string]monotime.Time LastActivitiesFunc func() map[string]monotime.Time
} }
// RemoveEndpointAddress is a no-op on the mock: it ignores the peer key and
// always returns nil.
func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
return nil
}
func (m *MockWGIface) FullStats() (*configurer.Stats, error) { func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
return nil, fmt.Errorf("not implemented") return nil, fmt.Errorf("not implemented")
} }

View File

@@ -28,6 +28,7 @@ type wgIfaceBase interface {
UpdateAddr(newAddr string) error UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemoveEndpointAddress(key string) error
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error

View File

@@ -430,6 +430,9 @@ func (conn *Conn) onICEStateDisconnected() {
} else { } else {
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String()) conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
conn.currentConnPriority = conntype.None conn.currentConnPriority = conntype.None
if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
}
} }
changed := conn.statusICE.Get() != worker.StatusDisconnected changed := conn.statusICE.Get() != worker.StatusDisconnected
@@ -523,6 +526,9 @@ func (conn *Conn) onRelayDisconnected() {
if conn.currentConnPriority == conntype.Relay { if conn.currentConnPriority == conntype.Relay {
conn.Log.Debugf("clean up WireGuard config") conn.Log.Debugf("clean up WireGuard config")
conn.currentConnPriority = conntype.None conn.currentConnPriority = conntype.None
if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
}
} }
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {

View File

@@ -18,4 +18,5 @@ type WGIface interface {
GetStats() (map[string]configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
Address() wgaddr.Address Address() wgaddr.Address
RemoveEndpointAddress(key string) error
} }

View File

@@ -47,7 +47,7 @@ services:
- traefik.enable=true - traefik.enable=true
- traefik.http.routers.netbird-wsproxy-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/ws-proxy/signal`) - traefik.http.routers.netbird-wsproxy-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/ws-proxy/signal`)
- traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal - traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal
- traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=10000 - traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=80
- traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`) - traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`)
- traefik.http.services.netbird-signal.loadbalancer.server.port=10000 - traefik.http.services.netbird-signal.loadbalancer.server.port=10000
- traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c - traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c

View File

@@ -621,7 +621,7 @@ renderCaddyfile() {
# relay # relay
reverse_proxy /relay* relay:80 reverse_proxy /relay* relay:80
# Signal # Signal
reverse_proxy /ws-proxy/signal* signal:10000 reverse_proxy /ws-proxy/signal* signal:80
reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000 reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000
# Management # Management
reverse_proxy /api/* management:80 reverse_proxy /api/* management:80

View File

@@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"net/netip"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -252,7 +251,7 @@ func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config)
} }
func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler { func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), ManagementLegacyPort), wsproxyserver.WithOTelMeter(meter)) wsProxy := wsproxyserver.New(gRPCHandler, wsproxyserver.WithOTelMeter(meter))
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
switch { switch {

View File

@@ -10,7 +10,6 @@ import (
"net/http" "net/http"
// nolint:gosec // nolint:gosec
_ "net/http/pprof" _ "net/http/pprof"
"net/netip"
"time" "time"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
@@ -63,10 +62,10 @@ var (
Use: "run", Use: "run",
Short: "start NetBird Signal Server daemon", Short: "start NetBird Signal Server daemon",
SilenceUsage: true, SilenceUsage: true,
PreRun: func(cmd *cobra.Command, args []string) { PreRunE: func(cmd *cobra.Command, args []string) error {
err := util.InitLog(logLevel, logFile) err := util.InitLog(logLevel, logFile)
if err != nil { if err != nil {
log.Fatalf("failed initializing log %v", err) return fmt.Errorf("failed initializing log: %w", err)
} }
flag.Parse() flag.Parse()
@@ -87,6 +86,8 @@ var (
signalPort = 80 signalPort = 80
} }
} }
return nil
}, },
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
flag.Parse() flag.Parse()
@@ -254,7 +255,7 @@ func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler h
} }
func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler { func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler {
wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), legacyGRPCPort), wsproxyserver.WithOTelMeter(meter)) wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch { switch {

View File

@@ -2,42 +2,41 @@ package server
import ( import (
"context" "context"
"errors"
"io" "io"
"net" "net"
"net/http" "net/http"
"net/netip"
"sync" "sync"
"time" "time"
"github.com/coder/websocket" "github.com/coder/websocket"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
"github.com/netbirdio/netbird/util/wsproxy" "github.com/netbirdio/netbird/util/wsproxy"
) )
const ( const (
dialTimeout = 10 * time.Second
bufferSize = 32 * 1024 bufferSize = 32 * 1024
ioTimeout = 5 * time.Second
) )
// Config contains the configuration for the WebSocket proxy. // Config contains the configuration for the WebSocket proxy.
type Config struct { type Config struct {
LocalGRPCAddr netip.AddrPort Handler http.Handler
Path string Path string
MetricsRecorder MetricsRecorder MetricsRecorder MetricsRecorder
} }
// Proxy handles WebSocket to TCP proxying for gRPC connections. // Proxy handles WebSocket to gRPC handler proxying.
type Proxy struct { type Proxy struct {
config Config config Config
metrics MetricsRecorder metrics MetricsRecorder
} }
// New creates a new WebSocket proxy instance with optional configuration // New creates a new WebSocket proxy instance with optional configuration
func New(localGRPCAddr netip.AddrPort, opts ...Option) *Proxy { func New(handler http.Handler, opts ...Option) *Proxy {
config := Config{ config := Config{
LocalGRPCAddr: localGRPCAddr, Handler: handler,
Path: wsproxy.ProxyPath, Path: wsproxy.ProxyPath,
MetricsRecorder: NoOpMetricsRecorder{}, // Default to no-op MetricsRecorder: NoOpMetricsRecorder{}, // Default to no-op
} }
@@ -63,7 +62,7 @@ func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) {
p.metrics.RecordConnection(ctx) p.metrics.RecordConnection(ctx)
defer p.metrics.RecordDisconnection(ctx) defer p.metrics.RecordDisconnection(ctx)
log.Debugf("WebSocket proxy handling connection from %s, forwarding to %s", r.RemoteAddr, p.config.LocalGRPCAddr) log.Debugf("WebSocket proxy handling connection from %s, forwarding to internal gRPC handler", r.RemoteAddr)
acceptOptions := &websocket.AcceptOptions{ acceptOptions := &websocket.AcceptOptions{
OriginPatterns: []string{"*"}, OriginPatterns: []string{"*"},
} }
@@ -75,71 +74,41 @@ func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) {
return return
} }
defer func() { defer func() {
if err := wsConn.Close(websocket.StatusNormalClosure, ""); err != nil { _ = wsConn.Close(websocket.StatusNormalClosure, "")
log.Debugf("Failed to close WebSocket: %v", err)
}
}() }()
log.Debugf("WebSocket proxy attempting to connect to local gRPC at %s", p.config.LocalGRPCAddr) clientConn, serverConn := net.Pipe()
tcpConn, err := net.DialTimeout("tcp", p.config.LocalGRPCAddr.String(), dialTimeout)
if err != nil {
p.metrics.RecordError(ctx, "tcp_dial_failed")
log.Warnf("Failed to connect to local gRPC server at %s: %v", p.config.LocalGRPCAddr, err)
if err := wsConn.Close(websocket.StatusInternalError, "Backend unavailable"); err != nil {
log.Debugf("Failed to close WebSocket after connection failure: %v", err)
}
return
}
defer func() { defer func() {
if err := tcpConn.Close(); err != nil { _ = clientConn.Close()
log.Debugf("Failed to close TCP connection: %v", err) _ = serverConn.Close()
}
}() }()
log.Debugf("WebSocket proxy established: client %s -> local gRPC %s", r.RemoteAddr, p.config.LocalGRPCAddr) log.Debugf("WebSocket proxy established: %s -> gRPC handler", r.RemoteAddr)
p.proxyData(ctx, wsConn, tcpConn) go func() {
(&http2.Server{}).ServeConn(serverConn, &http2.ServeConnOpts{
Context: ctx,
Handler: p.config.Handler,
})
}()
p.proxyData(ctx, wsConn, clientConn, r.RemoteAddr)
} }
func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, tcpConn net.Conn) { func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) {
proxyCtx, cancel := context.WithCancel(ctx) proxyCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
go p.wsToTCP(proxyCtx, cancel, &wg, wsConn, tcpConn) go p.wsToPipe(proxyCtx, cancel, &wg, wsConn, pipeConn, clientAddr)
go p.tcpToWS(proxyCtx, cancel, &wg, wsConn, tcpConn) go p.pipeToWS(proxyCtx, cancel, &wg, wsConn, pipeConn, clientAddr)
done := make(chan struct{})
go func() {
wg.Wait() wg.Wait()
close(done)
}()
select {
case <-done:
log.Tracef("Proxy data transfer completed, both goroutines terminated")
case <-proxyCtx.Done():
log.Tracef("Proxy data transfer cancelled, forcing connection closure")
if err := wsConn.Close(websocket.StatusGoingAway, "proxy cancelled"); err != nil {
log.Tracef("Error closing WebSocket during cancellation: %v", err)
}
if err := tcpConn.Close(); err != nil {
log.Tracef("Error closing TCP connection during cancellation: %v", err)
} }
select { func (p *Proxy) wsToPipe(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) {
case <-done:
log.Tracef("Goroutines terminated after forced connection closure")
case <-time.After(2 * time.Second):
log.Tracef("Goroutines did not terminate within timeout after connection closure")
}
}
}
func (p *Proxy) wsToTCP(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) {
defer wg.Done() defer wg.Done()
defer cancel() defer cancel()
@@ -148,80 +117,73 @@ func (p *Proxy) wsToTCP(ctx context.Context, cancel context.CancelFunc, wg *sync
if err != nil { if err != nil {
switch { switch {
case ctx.Err() != nil: case ctx.Err() != nil:
log.Debugf("wsToTCP goroutine terminating due to context cancellation") log.Debugf("WebSocket from %s terminating due to context cancellation", clientAddr)
case websocket.CloseStatus(err) == websocket.StatusNormalClosure: case websocket.CloseStatus(err) != -1:
log.Debugf("WebSocket closed normally") log.Debugf("WebSocket from %s disconnected", clientAddr)
default: default:
p.metrics.RecordError(ctx, "websocket_read_error") p.metrics.RecordError(ctx, "websocket_read_error")
log.Errorf("WebSocket read error: %v", err) log.Debugf("WebSocket read error from %s: %v", clientAddr, err)
} }
return return
} }
if msgType != websocket.MessageBinary { if msgType != websocket.MessageBinary {
log.Warnf("Unexpected WebSocket message type: %v", msgType) log.Warnf("Unexpected WebSocket message type from %s: %v", clientAddr, msgType)
continue continue
} }
if ctx.Err() != nil { if ctx.Err() != nil {
log.Tracef("wsToTCP goroutine terminating due to context cancellation before TCP write") log.Tracef("wsToPipe goroutine terminating due to context cancellation before pipe write")
return return
} }
if err := tcpConn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil { if err := pipeConn.SetWriteDeadline(time.Now().Add(ioTimeout)); err != nil {
log.Debugf("Failed to set TCP write deadline: %v", err) log.Debugf("Failed to set pipe write deadline: %v", err)
} }
n, err := tcpConn.Write(data) n, err := pipeConn.Write(data)
if err != nil { if err != nil {
p.metrics.RecordError(ctx, "tcp_write_error") p.metrics.RecordError(ctx, "pipe_write_error")
log.Errorf("TCP write error: %v", err) log.Warnf("Pipe write error for %s: %v", clientAddr, err)
return return
} }
p.metrics.RecordBytesTransferred(ctx, "ws_to_tcp", int64(n)) p.metrics.RecordBytesTransferred(ctx, "ws_to_grpc", int64(n))
} }
} }
func (p *Proxy) tcpToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) { func (p *Proxy) pipeToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) {
defer wg.Done() defer wg.Done()
defer cancel() defer cancel()
buf := make([]byte, bufferSize) buf := make([]byte, bufferSize)
for { for {
if err := tcpConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { n, err := pipeConn.Read(buf)
log.Debugf("Failed to set TCP read deadline: %v", err)
}
n, err := tcpConn.Read(buf)
if err != nil { if err != nil {
if ctx.Err() != nil { if ctx.Err() != nil {
log.Tracef("tcpToWS goroutine terminating due to context cancellation") log.Tracef("pipeToWS goroutine terminating due to context cancellation")
return return
} }
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
continue
}
if err != io.EOF { if err != io.EOF {
log.Errorf("TCP read error: %v", err) log.Debugf("Pipe read error for %s: %v", clientAddr, err)
} }
return return
} }
if ctx.Err() != nil { if ctx.Err() != nil {
log.Tracef("tcpToWS goroutine terminating due to context cancellation before WebSocket write") log.Tracef("pipeToWS goroutine terminating due to context cancellation before WebSocket write")
return return
} }
if n > 0 {
if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil { if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil {
p.metrics.RecordError(ctx, "websocket_write_error") p.metrics.RecordError(ctx, "websocket_write_error")
log.Errorf("WebSocket write error: %v", err) log.Warnf("WebSocket write error for %s: %v", clientAddr, err)
return return
} }
p.metrics.RecordBytesTransferred(ctx, "tcp_to_ws", int64(n)) p.metrics.RecordBytesTransferred(ctx, "grpc_to_ws", int64(n))
}
} }
} }